1
|
|
|
# Licensed under a 3-clause BSD style license - see LICENSE.rst |
2
|
|
|
from collections import OrderedDict |
3
|
|
|
import logging |
4
|
|
|
import numpy as np |
5
|
|
|
import scipy.interpolate |
6
|
|
|
import scipy.ndimage |
7
|
|
|
import scipy.signal |
8
|
|
|
from astropy.table import Table |
9
|
|
|
import astropy.units as u |
10
|
|
|
from astropy.convolution import Tophat2DKernel |
11
|
|
|
from astropy.io import fits |
12
|
|
|
from astropy.nddata import Cutout2D |
13
|
|
|
from gammapy.extern.skimage import block_reduce |
14
|
|
|
from gammapy.utils.interpolation import ScaledRegularGridInterpolator |
15
|
|
|
from gammapy.utils.units import unit_from_fits_image_hdu |
16
|
|
|
from .geom import pix_tuple_to_idx, MapCoord |
17
|
|
|
from .reproject import reproject_car_to_hpx, reproject_car_to_wcs |
18
|
|
|
from .utils import INVALID_INDEX, interp_to_order |
19
|
|
|
from .utils import get_random_state |
20
|
|
|
from gammapy.utils.random import InverseCDFSampler, get_random_state |
21
|
|
|
from .wcs import _check_width |
22
|
|
|
from .wcsmap import WcsGeom, WcsMap |
23
|
|
|
|
24
|
|
|
# Public API of this module.
__all__ = ["WcsNDMap"]

# Module-level logger (used e.g. by ``WcsNDMap._plot_format``).
log = logging.getLogger(__name__)
27
|
|
|
|
28
|
|
|
|
29
|
|
|
class WcsNDMap(WcsMap):
    """WCS map with any number of non-spatial dimensions.

    This class uses an ND numpy array to store map values. For maps with
    non-spatial dimensions and variable pixel size it will allocate an
    array with dimensions commensurate with the largest image plane.

    Parameters
    ----------
    geom : `~gammapy.maps.WcsGeom`
        WCS geometry object.
    data : `~numpy.ndarray`
        Data array. If None, a default array is allocated (see
        ``_make_default_data``): zeros for pixels belonging to the
        geometry, NaN for pixels that do not.
    dtype : str, optional
        Data type, default is float32
    meta : `dict`
        Dictionary to store meta data.
    unit : str or `~astropy.units.Unit`
        The map unit
    """
49
|
|
|
|
50
|
|
|
def __init__(self, geom, data=None, dtype="float32", meta=None, unit=""):
    """Create the map from a geometry and an optional data array.

    Parameters
    ----------
    geom : `~gammapy.maps.WcsGeom`
        WCS geometry object.
    data : `~numpy.ndarray`, optional
        Data array. When None, a default array of shape ``geom.data_shape``
        and type ``dtype`` is built by ``_make_default_data``.
    dtype : str, optional
        Data type used only when ``data`` is None.
    meta : dict, optional
        Meta data dictionary.
    unit : str or `~astropy.units.Unit`
        The map unit.
    """
    # TODO: Figure out how to mask pixels for integer data types
    default_shape = geom.data_shape
    if data is None:
        data = self._make_default_data(geom, default_shape, dtype)
    super().__init__(geom, data, meta, unit)
59
|
|
|
|
60
|
|
|
@staticmethod
def _make_default_data(geom, shape_np, dtype):
    """Allocate the default data array for a geometry.

    Pixels belonging to the geometry are set to 0. The remaining pixels are
    set to NaN: the unused padding of smaller image planes in a non-regular
    geometry, or pixels for which ``geom.get_idx()`` returns
    ``INVALID_INDEX`` when some image corner has non-finite coordinates.

    Parameters
    ----------
    geom : `~gammapy.maps.WcsGeom`
        Map geometry.
    shape_np : tuple of int
        Shape of the numpy array (``geom.data_shape``). It is the transpose
        of the (x, y, axes...) pixel order used throughout this class.
    dtype : str or `~numpy.dtype`
        Data type of the array.

    Returns
    -------
    data : `~numpy.ndarray`
        Default data array.
    """
    # Check whether corners of each image plane have finite coordinates
    coords = []
    if not geom.is_regular:
        for idx in np.ndindex(geom.shape_axes):
            pix = (
                np.array([0.0, float(geom.npix[0][idx] - 1)]),
                np.array([0.0, float(geom.npix[1][idx] - 1)]),
            )
            pix += tuple([np.array(2 * [t]) for t in idx])
            coords += geom.pix_to_coord(pix)
    else:
        pix = (
            np.array([0.0, float(geom.npix[0] - 1)]),
            np.array([0.0, float(geom.npix[1] - 1)]),
        )
        pix += tuple([np.array(2 * [0.0]) for _ in range(geom.ndim - 2)])
        coords += geom.pix_to_coord(pix)

    if np.all(np.isfinite(np.vstack(coords))):
        if geom.is_regular:
            data = np.zeros(shape_np, dtype=dtype)
        else:
            # Image planes may be smaller than the allocated array: start
            # from NaN and zero only the valid region of each plane.
            data = np.full(shape_np, np.nan, dtype=dtype)
            for idx in np.ndindex(geom.shape_axes):
                nx = int(geom.npix[0][idx])
                ny = int(geom.npix[1][idx])
                # Index through the transposed view so the index order is
                # (x, y, axes...), consistent with ``geom.npix`` and ``idx``.
                data.T[(slice(nx), slice(ny)) + idx] = 0.0
    else:
        data = np.full(shape_np, np.nan, dtype=dtype)
        idx = geom.get_idx()
        m = np.all(np.stack([t != INVALID_INDEX.int for t in idx]), axis=0)
        data[m] = 0.0

    return data
95
|
|
|
|
96
|
|
|
@classmethod
def from_hdu(cls, hdu, hdu_bands=None):
    """Make a WcsNDMap object from a FITS HDU.

    Parameters
    ----------
    hdu : `~astropy.io.fits.BinTableHDU` or `~astropy.io.fits.ImageHDU`
        The map FITS HDU. A binary table is read as a sparse map from its
        ``PIX`` / ``VALUE`` columns (plus ``CHANNEL`` for non-spatial bins
        when present); for any other HDU its ``data`` is used directly.
    hdu_bands : `~astropy.io.fits.BinTableHDU`
        The BANDS table HDU.

    Returns
    -------
    map_out : `WcsNDMap`
        The map read from the HDU.
    """
    geom = WcsGeom.from_header(hdu.header, hdu_bands)
    shape_axes = tuple([ax.nbin for ax in geom.axes])
    shape_image = tuple([np.max(geom.npix[0]), np.max(geom.npix[1])])

    map_out = cls(
        geom,
        meta=cls._get_meta_from_header(hdu.header),
        unit=unit_from_fits_image_hdu(hdu.header),
    )

    # TODO: Should we support extracting slices?
    if not isinstance(hdu, fits.BinTableHDU):
        map_out.data = hdu.data
        return map_out

    # Flat pixel numbers are unravelled in (ny, nx) order, giving (y, x)
    pix = np.unravel_index(hdu.data.field("PIX"), shape_image[::-1])
    vals = hdu.data.field("VALUE")
    has_channel = "CHANNEL" in hdu.data.columns.names and shape_axes
    if has_channel:
        chan = np.unravel_index(hdu.data.field("CHANNEL"), shape_axes[::-1])
        idx = chan + pix
    else:
        idx = pix

    # Reverse to the (x, y, axes...) order expected by set_by_idx
    map_out.set_by_idx(idx[::-1], vals)
    return map_out
132
|
|
|
|
133
|
|
|
def get_by_idx(self, idx):
    """Return map values at the given pixel indices.

    ``idx`` is a tuple of index arrays in (x, y, axes...) order. It is
    normalised with ``pix_tuple_to_idx`` and applied to the transposed data
    array.
    """
    index = pix_tuple_to_idx(idx)
    return self.data.T[index]
136
|
|
|
|
137
|
|
|
def interp_by_coord(self, coords, interp=None, fill_value=None):
    """Interpolate map values at the given map coordinates.

    For regular geometries the coordinates are converted to pixels and
    passed to ``interp_by_pix``. For non-regular geometries
    ``_interp_by_coord_griddata`` is used, and ``fill_value`` is not used.
    """
    if not self.geom.is_regular:
        return self._interp_by_coord_griddata(coords, interp=interp)
    pixels = self.geom.coord_to_pix(coords)
    return self.interp_by_pix(pixels, interp=interp, fill_value=fill_value)
144
|
|
|
|
145
|
|
|
def interp_by_pix(self, pix, interp=None, fill_value=None):
    """Interpolate map values at the given pixel coordinates.

    Parameters
    ----------
    pix : tuple
        Pixel coordinates in (x, y, axes...) order.
    interp : str or int, optional
        Interpolation method, converted with ``interp_to_order``. Orders 0
        and 1 use a regular-grid interpolator; orders 2 and 3 use
        `scipy.ndimage.map_coordinates`.
    fill_value : float, optional
        Value for points outside the grid (orders 0 and 1 only).

    Raises
    ------
    ValueError
        If the geometry is not regular or the order is not in 0-3.
    """
    if not self.geom.is_regular:
        raise ValueError("interp_by_pix only supported for regular geom.")

    order = interp_to_order(interp)
    if order in (0, 1):
        return self._interp_by_pix_linear_grid(pix, order=order, fill_value=fill_value)
    if order in (2, 3):
        return self._interp_by_pix_map_coordinates(pix, order=order)
    raise ValueError("Invalid interpolation order: {!r}".format(order))
160
|
|
|
|
161
|
|
|
def _interp_by_pix_linear_grid(self, pix, order=1, fill_value=None):
    """Interpolate on the pixel grid with nearest (0) or linear (1) order.

    Non-finite map values are replaced by 0 in a copy before interpolation,
    unless the map has no finite values at all. Points outside the grid get
    ``fill_value``.
    """
    # TODO: Cache interpolator
    methods = {0: "nearest", 1: "linear"}
    if order not in methods:
        raise ValueError("Invalid interpolation order: {!r}".format(order))
    method = methods[order]

    # One pixel-centre axis per data dimension, in (x, y, axes...) order
    grid_pix = [np.arange(n, dtype=float) for n in self.data.shape[::-1]]

    if not np.any(np.isfinite(self.data)):
        values = self.data.T
    else:
        values = self.data.copy().T
        values[~np.isfinite(values)] = 0.0

    interpolator = ScaledRegularGridInterpolator(
        grid_pix, values, fill_value=fill_value, bounds_error=False, method=method
    )
    return interpolator(tuple(pix), clip=False)
181
|
|
|
|
182
|
|
|
def _interp_by_pix_map_coordinates(self, pix, order=1):
    """Spline-interpolate the map at pixel coordinates in (x, y, axes...) order.

    Scalars and 0-d arrays are promoted to 1-element arrays. Points outside
    the array take the nearest edge value (``mode="nearest"``).
    """
    coords = []
    for p in pix:
        if isinstance(p, np.ndarray) and p.ndim > 0:
            coords.append(p)
        else:
            coords.append(np.array(p, ndmin=1))
    return scipy.ndimage.map_coordinates(
        self.data.T, tuple(coords), order=order, mode="nearest"
    )
194
|
|
|
|
195
|
|
|
def _interp_by_coord_griddata(self, coords, interp=None):
    """Interpolate a non-regular map with `scipy.interpolate.griddata`.

    Points the requested method cannot interpolate (returned as non-finite
    values) are filled with nearest-neighbour values.

    NOTE(review): assumes ``geom.get_coord(flat=True)`` and the finite data
    values list the same pixels in the same order; confirm this holds when
    valid pixels contain NaN.
    """
    methods = {0: "nearest", 1: "linear", 3: "cubic"}
    method = methods.get(interp_to_order(interp), None)
    if method is None:
        raise ValueError("Invalid interp: {!r}".format(interp))

    grid_coords = tuple(self.geom.get_coord(flat=True))
    data = self.data[np.isfinite(self.data)]
    vals = scipy.interpolate.griddata(grid_coords, data, tuple(coords), method=method)

    missing = ~np.isfinite(vals)
    if np.any(missing):
        targets = tuple([c[missing] for c in coords])
        vals[missing] = scipy.interpolate.griddata(
            grid_coords, data, targets, method="nearest"
        )
    return vals
216
|
|
|
|
217
|
|
|
def fill_by_idx(self, idx, weights=None):
    """Add ``weights`` (or 1 per entry when None) to the pixels at ``idx``.

    Entries with an ``INVALID_INDEX.int`` component are dropped. Repeated
    indices accumulate. Quantity weights are converted to the map unit.
    """
    idx = pix_tuple_to_idx(idx)
    valid = np.all(np.stack([t != INVALID_INDEX.int for t in idx]), axis=0)
    idx = [t[valid] for t in idx]

    if weights is not None:
        if isinstance(weights, u.Quantity):
            weights = weights.to_value(self.unit)
        weights = weights[valid]

    # Collapse duplicate pixels so each one is updated exactly once
    flat_idx = np.ravel_multi_index(idx, self.data.T.shape)
    unique_idx, inverse = np.unique(flat_idx, return_inverse=True)
    summed = np.bincount(inverse, weights=weights).astype(self.data.dtype)
    self.data.T.flat[unique_idx] += summed
231
|
|
|
|
232
|
|
|
def set_by_idx(self, idx, vals):
    """Set map values at pixel indices given in (x, y, axes...) order."""
    index = pix_tuple_to_idx(idx)
    self.data.T[index] = vals
235
|
|
|
|
236
|
|
|
def sum_over_axes(self, keepdims=False):
    """Sum map values over all non-spatial axes (NaNs are ignored).

    Parameters
    ----------
    keepdims : bool, optional
        If True, each summed axis is kept in the map as a single bin.

    Returns
    -------
    map_out : WcsNDMap
        Map with non-spatial axes summed over.
    """
    # The two spatial dimensions are the last ones of the data array
    non_spatial = tuple(range(self.data.ndim - 2))
    geom = self.geom.to_image()
    if keepdims:
        for axis in self.geom.axes:
            geom = geom.to_cube([axis.squash()])
    summed = np.nansum(self.data, axis=non_spatial, keepdims=keepdims)
    # TODO: summing over the axis can change the unit, handle this correctly
    return self._init_copy(geom=geom, data=summed)
258
|
|
|
|
259
|
|
|
def _reproject_to_wcs(self, geom, mode="interp", order=1):
    """Reproject the map onto a WCS geometry, one image plane at a time.

    Parameters
    ----------
    geom : `~gammapy.maps.WcsGeom`
        Target geometry.
    mode : {'interp', 'exact'}
        Use ``reproject_interp`` or ``reproject_exact``. Not used when the
        input map is an all-sky CAR map, which always goes through
        ``reproject_car_to_wcs``.
    order : int or str
        Interpolation order forwarded to ``reproject_interp``; only used
        with ``mode='interp'``. It was previously accepted but ignored.

    Returns
    -------
    map_out : `WcsNDMap`
        Reprojected map.

    Raises
    ------
    TypeError
        If ``mode`` is neither 'interp' nor 'exact'.
    """
    from reproject import reproject_interp, reproject_exact

    data = np.empty(geom.data_shape)

    for img, idx in self.iter_by_image():
        # TODO: Create WCS object for image plane if
        # multi-resolution geom
        shape_out = geom.get_image_shape(idx)[::-1]

        if self.geom.projection == "CAR" and self.geom.is_allsky:
            vals, footprint = reproject_car_to_wcs(
                (img, self.geom.wcs), geom.wcs, shape_out=shape_out
            )
        elif mode == "interp":
            vals, footprint = reproject_interp(
                (img, self.geom.wcs), geom.wcs, shape_out=shape_out, order=order
            )
        elif mode == "exact":
            vals, footprint = reproject_exact(
                (img, self.geom.wcs), geom.wcs, shape_out=shape_out
            )
        else:
            raise TypeError(
                "mode must be 'interp' or 'exact'. Got: {!r}".format(mode)
            )

        data[idx] = vals

    return self._init_copy(geom=geom, data=data)
289
|
|
|
|
290
|
|
|
def _reproject_to_hpx(self, geom, mode="interp", order=1):
    """Reproject the map onto a HEALPix geometry, one image plane at a time.

    All-sky CAR maps use ``reproject_car_to_hpx``; all other maps use
    ``reproject.reproject_to_healpix``. ``order`` is forwarded to either
    function; ``mode`` is not used.
    """
    from reproject import reproject_to_healpix

    data = np.empty(geom.data_shape)
    coordsys = "galactic" if geom.coordsys == "GAL" else "icrs"

    for img, idx in self.iter_by_image():
        # TODO: For partial-sky HPX we need to map from full- to
        # partial-sky indices
        if self.geom.projection == "CAR" and self.geom.is_allsky:
            reproject_func = reproject_car_to_hpx
        else:
            reproject_func = reproject_to_healpix
        vals, footprint = reproject_func(
            (img, self.geom.wcs),
            coordsys,
            nside=geom.nside,
            nested=geom.nest,
            order=order,
        )
        data[idx] = vals

    return self._init_copy(geom=geom, data=data)
318
|
|
|
|
319
|
|
|
def pad(self, pad_width, mode="constant", cval=0, order=1):
    """Pad the spatial dimensions of the map.

    Parameters
    ----------
    pad_width : int or sequence of int
        Number of pixels added on each side, as a single value or ``(x, y)``.
    mode : str
        ``'constant'`` pads with ``cval``; ``'interp'`` fills the new pixels
        by interpolating the map. For regular geometries any other mode is
        forwarded to `numpy.pad`.
    cval : float
        Constant value for ``mode='constant'``.
    order : int
        Interpolation order for ``mode='interp'``.

    Returns
    -------
    map_out : `WcsNDMap`
        Padded map.
    """
    if np.isscalar(pad_width):
        pad_width = (pad_width, pad_width)
    # Build a new tuple rather than extending in place: ``+=`` on a list
    # argument mutated the caller's object, and on a numpy array it
    # broadcast instead of appending.
    pad_width = tuple(pad_width) + (0,) * (self.geom.ndim - 2)

    geom = self.geom.pad(pad_width[:2])
    if self.geom.is_regular and mode != "interp":
        return self._pad_np(geom, pad_width, mode, cval)
    else:
        return self._pad_coadd(geom, pad_width, mode, cval, order)
329
|
|
|
|
330
|
|
|
def _pad_np(self, geom, pad_width, mode, cval):
    """Pad a map using ``numpy.pad``.

    This method only works for regular geometries but should be more
    efficient when working with large maps.

    Parameters
    ----------
    geom : `~gammapy.maps.WcsGeom`
        Geometry of the padded map.
    pad_width : sequence of int
        Pad width per axis in (x, y, axes...) order; each value is applied
        on both sides of that axis.
    mode : str
        `numpy.pad` mode.
    cval : float
        Fill value for ``mode='constant'``. It was previously collected but
        never passed to `numpy.pad`.
    """
    kwargs = {}
    if mode == "constant":
        kwargs["constant_values"] = cval

    # numpy.pad wants (before, after) pairs in data-array order, which is
    # the reverse of the (x, y, axes...) pixel order.
    pad_width = [(t, t) for t in pad_width]
    data = np.pad(self.data, pad_width[::-1], mode=mode, **kwargs)
    return self._init_copy(geom=geom, data=data)
343
|
|
|
|
344
|
|
|
def _pad_coadd(self, geom, pad_width, mode, cval, order):
    """Pad a map manually by coadding the original map with the new map.

    ``mode='constant'`` sets every new pixel to ``cval``.
    ``mode='interp'`` fills pixels outside the original geometry by
    interpolating the original map with ``order``.

    Raises
    ------
    ValueError
        For any other ``mode``.
    """
    # Original pixel indices shifted into the padded frame, then reversed
    # to data-array order
    shifted = tuple([t + w for t, w in zip(self.geom.get_idx(flat=True), pad_width)])
    idx_in = shifted[::-1]
    idx_out = geom.get_idx(flat=True)[::-1]

    map_out = self._init_copy(geom=geom, data=None)
    map_out.coadd(self)

    if mode == "interp":
        coords = geom.pix_to_coord(idx_out[::-1])
        inside = self.geom.contains(coords)
        outside_coords = tuple([c[~inside] for c in coords])
        vals = self.interp_by_coord(outside_coords, interp=order)
        map_out.set_by_coord(outside_coords, vals)
    elif mode == "constant":
        is_padding = np.zeros_like(map_out.data, dtype=bool)
        is_padding[idx_out] = True
        is_padding[idx_in] = False
        map_out.data[is_padding] = cval
    else:
        raise ValueError("Invalid mode: {!r}".format(mode))

    return map_out
367
|
|
|
|
368
|
|
|
def crop(self, crop_width):
    """Crop ``crop_width`` pixels from each side of the spatial dimensions.

    ``crop_width`` is a single value or an ``(x, y)`` pair. Regular maps are
    sliced directly; non-regular maps are rebuilt by coadding into the
    cropped geometry.
    """
    if np.isscalar(crop_width):
        crop_width = (crop_width, crop_width)

    geom = self.geom.crop(crop_width)
    if not self.geom.is_regular:
        # FIXME: This could be done more efficiently by
        # constructing the appropriate slices for each image plane
        map_out = self._init_copy(geom=geom, data=None)
        map_out.coadd(self)
        return map_out

    # Keep every non-spatial axis, then trim y and x (data-array order)
    keep_axes = (slice(None),) * len(self.geom.axes)
    trim_y = slice(crop_width[1], int(self.geom.npix[1] - crop_width[1]))
    trim_x = slice(crop_width[0], int(self.geom.npix[0] - crop_width[0]))
    return self._init_copy(geom=geom, data=self.data[keep_axes + (trim_y, trim_x)])
388
|
|
|
|
389
|
|
|
def upsample(self, factor, order=0, preserve_counts=True, axis=None):
    """Upsample the map by an integer factor.

    Parameters
    ----------
    factor : int
        Upsampling factor.
    order : int
        Spline order passed to `scipy.ndimage.map_coordinates`.
    preserve_counts : bool
        Divide values by the number of new pixels per original pixel so that
        the map total is conserved.
    axis : str, optional
        Name of a non-spatial axis to upsample. If None, both spatial axes
        are upsampled.

    Returns
    -------
    map_out : `WcsNDMap`
        Upsampled map.
    """
    geom = self.geom.upsample(factor, axis=axis)
    idx = geom.get_idx()

    if axis is None:
        pix = (
            (idx[0] - 0.5 * (factor - 1)) / factor,
            (idx[1] - 0.5 * (factor - 1)) / factor,
        ) + idx[2:]
        n_sub = factor ** 2
    else:
        pix = list(idx)
        # ``pix`` is ordered (x, y, axes...): skip the two spatial entries.
        # NOTE(review): assumes get_axis_index_by_name returns the position
        # within ``geom.axes`` (non-spatial axes only), as ``downsample``
        # also assumes - confirm.
        idx_ax = self.geom.get_axis_index_by_name(axis) + 2
        pix[idx_ax] = (pix[idx_ax] - 0.5 * (factor - 1)) / factor
        n_sub = factor

    data = scipy.ndimage.map_coordinates(
        self.data.T, tuple(pix), order=order, mode="nearest"
    )

    if preserve_counts:
        # Out-of-place division: map_coordinates keeps the input dtype, and
        # an in-place true division fails for integer maps.
        data = data / n_sub

    return self._init_copy(geom=geom, data=data)
414
|
|
|
|
415
|
|
|
def downsample(self, factor, preserve_counts=True, axis=None):
    """Downsample the map by an integer factor using block reduction.

    Parameters
    ----------
    factor : int
        Downsampling factor.
    preserve_counts : bool
        Sum the values in each block (``np.nansum``) if True, otherwise
        average them (``np.nanmean``).
    axis : str, optional
        Name of a non-spatial axis to downsample. If None, both spatial axes
        are downsampled.

    Returns
    -------
    map_out : `WcsNDMap`
        Downsampled map.
    """
    geom = self.geom.downsample(factor, axis=axis)
    if axis is None:
        block_size = (factor, factor) + (1,) * len(self.geom.axes)
    else:
        # Block sizes in (x, y, axes...) order, reversed below to data order.
        # NOTE(review): assumes get_axis_index_by_name returns the position
        # within ``geom.axes``. The previous ``block_size[-(idx + 1)]`` only
        # hit the right axis for maps with a single non-spatial axis.
        block_size = [1] * self.data.ndim
        idx = self.geom.get_axis_index_by_name(axis)
        block_size[idx + 2] = factor

    func = np.nansum if preserve_counts else np.nanmean
    data = block_reduce(self.data, tuple(block_size[::-1]), func=func)

    return self._init_copy(geom=geom, data=data)
428
|
|
|
|
429
|
|
|
def plot(self, ax=None, fig=None, add_cbar=False, stretch="linear", **kwargs):
    """Plot image on matplotlib WCS axes.

    Parameters
    ----------
    ax : `~astropy.visualization.wcsaxes.WCSAxes`, optional
        WCS axis object to plot on. Created on ``fig`` if not given, with an
        elliptical frame for all-sky maps.
    fig : `~matplotlib.figure.Figure`
        Figure object; defaults to the current figure.
    add_cbar : bool
        Add color bar?
    stretch : str
        Passed to `astropy.visualization.simple_norm`.
    **kwargs : dict
        Keyword arguments passed to `~matplotlib.pyplot.imshow`.

    Returns
    -------
    fig : `~matplotlib.figure.Figure`
        Figure object.
    ax : `~astropy.visualization.wcsaxes.WCSAxes`
        WCS axis object
    cbar : `~matplotlib.colorbar.Colorbar` or None
        Colorbar object.

    Raises
    ------
    TypeError
        If the map is not a 2D image.
    """
    import matplotlib.pyplot as plt
    from astropy.visualization import simple_norm
    from astropy.visualization.wcsaxes.frame import EllipticalFrame

    if not self.geom.is_image:
        raise TypeError("Use .plot_interactive() for Map dimension > 2")

    fig = plt.gcf() if fig is None else fig

    if ax is None:
        subplot_kwargs = {"projection": self.geom.wcs}
        if self.geom.is_allsky:
            subplot_kwargs["frame_class"] = EllipticalFrame
        ax = fig.add_subplot(1, 1, 1, **subplot_kwargs)

    image = self.data.astype(float)

    defaults = {"interpolation": "nearest", "origin": "lower", "cmap": "afmhot"}
    for key, value in defaults.items():
        kwargs.setdefault(key, value)
    kwargs.setdefault("norm", simple_norm(image[np.isfinite(image)], stretch))

    caxes = ax.imshow(image, **kwargs)
    cbar = None
    if add_cbar:
        cbar = fig.colorbar(caxes, ax=ax, label=str(self.unit))

    if self.geom.is_allsky:
        ax = self._plot_format_allsky(ax)
    else:
        ax = self._plot_format(ax)

    # without this the axis limits are changed when calling scatter
    ax.autoscale(enable=False)
    return fig, ax, cbar
493
|
|
|
|
494
|
|
|
def _plot_format(self, ax):
    """Label the coordinate axes of a WCS plot.

    Galactic labels are tried first. If ``ax.coords`` has no ``glon`` or
    ``glat`` entry (``KeyError``), equatorial ``ra`` / ``dec`` labels are set
    instead. If an ``AttributeError`` is raised while setting the galactic
    labels (e.g. ``ax`` has no ``coords``), a message is logged instead.
    """
    try:
        for name, label in (("glon", "Galactic Longitude"), ("glat", "Galactic Latitude")):
            ax.coords[name].set_axislabel(label)
    except KeyError:
        for name, label in (("ra", "Right Ascension"), ("dec", "Declination")):
            ax.coords[name].set_axislabel(label)
    except AttributeError:
        log.info("Can't set coordinate axes. No WCS information available.")
    return ax
504
|
|
|
|
505
|
|
|
def _plot_format_allsky(self, ax):
    """Style an all-sky plot.

    Hides the frame, sets the axis limits from the pixel positions of
    lon=180 and lat=-90, adds a coordinate-system label, and draws white
    grid lines every 45 deg in longitude and 15 deg in latitude.
    """
    # Remove frame
    ax.coords.frame.set_linewidth(0)

    # Set plot axis limits
    ny, nx = self.data.shape
    xmargin, _ = self.geom.coord_to_pix({"lon": 180, "lat": 0})
    _, ymargin = self.geom.coord_to_pix({"lon": 0, "lat": -90})
    ax.set_xlim(xmargin, nx - xmargin)
    ax.set_ylim(ymargin, ny - ymargin)

    ax.text(0, ny, self.geom.coordsys + " coords")

    # Grid and ticks
    lon_axis, lat_axis = ax.coords
    lon_axis.set_ticks(spacing=45 * u.deg, color="w", alpha=0.8)
    lat_axis.set_ticks(spacing=15 * u.deg)
    lon_axis.set_ticks_visible(False)

    lon_axis.set_ticklabel(color="w", alpha=0.8)
    for coord in (lon_axis, lat_axis):
        coord.grid(alpha=0.2, linestyle="solid", color="w")
    return ax
530
|
|
|
|
531
|
|
|
def smooth(self, width, kernel="gauss", **kwargs):
    """Smooth the map.

    Iterates over 2D image planes, processing one at a time.

    Parameters
    ----------
    width : `~astropy.units.Quantity`, str or float
        Smoothing width given as quantity or float. A float is interpreted
        as a width in pixels. A quantity (or a string parsed as one) is
        converted to pixels by dividing by the mean of
        ``geom.pixel_scales``. It corresponds to the standard deviation for
        a Gaussian kernel, the radius for a disk kernel, and the side
        length for a box kernel.
    kernel : {'gauss', 'disk', 'box'}
        Kernel shape
    kwargs : dict
        Keyword arguments passed to `~scipy.ndimage.uniform_filter`
        ('box'), `~scipy.ndimage.gaussian_filter` ('gauss') or
        `~scipy.ndimage.convolve` ('disk').

    Returns
    -------
    image : `WcsNDMap`
        Smoothed image (a copy, the original object is unchanged).

    Raises
    ------
    ValueError
        If ``kernel`` is not one of the supported names.
    """
    if isinstance(width, (u.Quantity, str)):
        width = (u.Quantity(width) / self.geom.pixel_scales.mean()).to_value("")

    smoothed = np.empty(self.data.shape, dtype=float)

    for plane, idx in self.iter_by_image():
        plane = plane.astype(float)
        if kernel == "gauss":
            result = scipy.ndimage.gaussian_filter(plane, width, **kwargs)
        elif kernel == "disk":
            tophat = Tophat2DKernel(width)
            tophat.normalize("integral")
            result = scipy.ndimage.convolve(plane, tophat.array, **kwargs)
        elif kernel == "box":
            result = scipy.ndimage.uniform_filter(plane, width, **kwargs)
        else:
            raise ValueError("Invalid kernel: {!r}".format(kernel))
        smoothed[idx] = result

    return self._init_copy(data=smoothed)
578
|
|
|
|
579
|
|
|
def get_spectrum(self, region=None, func=np.nansum):
    """Extract spectrum in a given region.

    The spectrum is computed by applying ``func`` (a sum by default) over
    the spatial pixels of each energy bin. Only pixels inside ``region`` are
    used; by default the whole spatial extent of the map is used.

    Parameters
    ----------
    region : `~regions.Region`, optional
        Region (pixel or sky regions accepted).
    func : `~numpy.ufunc`
        Function to reduce the data.

    Returns
    -------
    spectrum : `CountsSpectrum`
        Spectrum in the given region.
    """
    from gammapy.spectrum import CountsSpectrum

    energy_axis = self.geom.get_axis_by_name("energy")

    if not region:
        data = func(self.data, axis=(1, 2))
    else:
        mask = self.geom.region_mask([region])
        # The reshape assumes the mask selects the same number of pixels in
        # every energy plane.
        per_bin = self.data[mask].reshape(energy_axis.nbin, -1)
        data = func(per_bin, axis=1)

    edges = energy_axis.edges
    return CountsSpectrum(
        data=data, energy_lo=edges[:-1], energy_hi=edges[1:], unit=self.unit
    )
613
|
|
|
|
614
|
|
|
def convolve(self, kernel, use_fft=True, **kwargs):
    """Convolve map with a kernel.

    If the kernel is two dimensional, it is applied to all image planes likewise.
    If the kernel is higher dimensional it must match the map in the number of
    dimensions and the corresponding kernel is selected for every image plane.

    Parameters
    ----------
    kernel : `~gammapy.cube.PSFKernel` or `numpy.ndarray`
        Convolution kernel.
    use_fft : bool
        Use `scipy.signal.fftconvolve` (with ``mode='same'`` by default) or
        `scipy.ndimage.convolve`.
    kwargs : dict
        Keyword arguments passed to `scipy.signal.fftconvolve` or
        `scipy.ndimage.convolve`.

    Returns
    -------
    map : `WcsNDMap`
        Convolved map.

    Raises
    ------
    ValueError
        If a `PSFKernel` pixel size does not match the map pixel size.
    """
    from gammapy.cube import PSFKernel

    conv_function = scipy.signal.fftconvolve if use_fft else scipy.ndimage.convolve
    convolved_data = np.empty(self.data.shape, dtype=np.float32)
    if use_fft:
        kwargs.setdefault("mode", "same")

    if isinstance(kernel, PSFKernel):
        kmap = kernel.psf_kernel_map
        if not np.allclose(
            self.geom.pixel_scales.deg, kmap.geom.pixel_scales.deg, rtol=1e-5
        ):
            raise ValueError("Pixel size of kernel and map not compatible.")
        kernel = kmap.data

    for img, idx in self.iter_by_image():
        # A 2D kernel is shared by every plane; otherwise use the matching
        # kernel plane. The result is always written to the current plane
        # only: writing through ``Ellipsis`` made each iteration overwrite
        # all planes.
        kernel_plane = kernel if kernel.ndim == 2 else kernel[idx]
        convolved_data[idx] = conv_function(img, kernel_plane, **kwargs)

    return self._init_copy(data=convolved_data)
657
|
|
|
|
658
|
|
|
def cutout(self, position, width, mode="trim"):
    """Create a cutout around a given position.

    The spatial slices are computed by `~astropy.nddata.Cutout2D` on the
    first image plane and then applied to every plane.

    Parameters
    ----------
    position : `~astropy.coordinates.SkyCoord`
        Center position of the cutout region.
    width : tuple of `~astropy.coordinates.Angle`
        Angular sizes of the region in (lon, lat) in that specific order.
        If only one value is passed, a square region is extracted.
    mode : {'trim', 'partial', 'strict'}
        Mode option for Cutout2D, for details see `~astropy.nddata.utils.Cutout2D`.

    Returns
    -------
    cutout : `~gammapy.maps.WcsNDMap`
        Cutout map
    """
    width = _check_width(width)
    first_plane = self.data[(0,) * len(self.geom.axes)]
    c2d = Cutout2D(
        data=first_plane,
        wcs=self.geom.wcs,
        position=position,
        # Cutout2D takes size with order (lat, lon)
        size=width[::-1] * u.deg,
        mode=mode,
    )

    geom = WcsGeom(c2d.wcs, c2d.shape[::-1], axes=self.geom.axes)
    # Same spatial slices for every non-spatial index
    slice_y = c2d.slices_original[0]
    slice_x = c2d.slices_original[1]
    return self._init_copy(geom=geom, data=self.data[Ellipsis, slice_y, slice_x])
695
|
|
|
|
696
|
|
|
def sample_coord(self, n_events, random_state=0):
    """Sample position and energy of events.

    The map values are used as the (unnormalised) sampling PDF.

    Parameters
    ----------
    n_events : int
        Number of events to sample.
    random_state : {int, 'random-seed', 'global-rng', `~numpy.random.RandomState`}
        Defines random number generator initialisation.
        Passed to `~gammapy.utils.random.get_random_state`.

    Returns
    -------
    coords : `~gammapy.maps.MapCoord` object.
        Sequence of coordinates and energies of the sampled events.
    """
    rng = get_random_state(random_state)
    sampler = InverseCDFSampler(pdf=self.data, random_state=rng)
    sampled_pix = sampler.sample(n_events)

    # Presumably the sampler returns pixels in data-array order, hence the
    # reversal to the (x, y, axes...) order used by pix_to_coord.
    coords = self.geom.pix_to_coord(sampled_pix[::-1])

    # TODO: pix_to_coord should return a MapCoord object
    names = ["lon", "lat"] + [axis.name for axis in self.geom.axes]
    cdict = OrderedDict(zip(names, coords))
    cdict["energy"] *= self.geom.get_axis_by_name("energy").unit

    return MapCoord.create(cdict, coordsys=self.geom.coordsys)
725
|
|
|
|