Passed
Pull Request — master (#2536)
by Axel
03:12
created

gammapy.maps.tests.test_wcs.test_cutout_info()   A

Complexity

Conditions 1

Size

Total Lines 25
Code Lines 20

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 20
dl 0
loc 25
rs 9.4
c 0
b 0
f 0
cc 1
nop 0
1
# Licensed under a 3-clause BSD style license - see LICENSE.rst
2
import pytest
3
import numpy as np
4
from numpy.testing import assert_allclose
5
import astropy.units as u
6
from astropy.coordinates import Angle, SkyCoord
7
from astropy.io import fits
8
from gammapy.maps import Map, MapAxis, WcsGeom
9
from gammapy.maps.wcs import _check_width
10
11
axes1 = [MapAxis(np.logspace(0.0, 3.0, 3), interp="log", name="energy")]
12
axes2 = [
13
    MapAxis(np.logspace(0.0, 3.0, 3), interp="log", name="energy"),
14
    MapAxis(np.logspace(1.0, 3.0, 4), interp="lin"),
15
]
16
skydir = SkyCoord(110.0, 75.0, unit="deg", frame="icrs")
17
18
wcs_allsky_test_geoms = [
19
    (None, 10.0, "GAL", "AIT", skydir, None),
20
    (None, 10.0, "GAL", "AIT", skydir, axes1),
21
    (None, [10.0, 20.0], "GAL", "AIT", skydir, axes1),
22
    (None, 10.0, "GAL", "AIT", skydir, axes2),
23
    (None, [[10.0, 20.0, 30.0], [10.0, 20.0, 30.0]], "GAL", "AIT", skydir, axes2),
24
]
25
26
wcs_partialsky_test_geoms = [
27
    (10, 0.1, "GAL", "AIT", skydir, None),
28
    (10, 0.1, "GAL", "AIT", skydir, axes1),
29
    (10, [0.1, 0.2], "GAL", "AIT", skydir, axes1),
30
]
31
32
wcs_test_geoms = wcs_allsky_test_geoms + wcs_partialsky_test_geoms
33
34
35
@pytest.mark.parametrize(
36
    ("npix", "binsz", "coordsys", "proj", "skydir", "axes"), wcs_test_geoms
37
)
38
def test_wcsgeom_init(npix, binsz, coordsys, proj, skydir, axes):
39
    WcsGeom.create(
40
        npix=npix, binsz=binsz, skydir=skydir, proj=proj, coordsys=coordsys, axes=axes
41
    )
42
43
44
@pytest.mark.parametrize(
45
    ("npix", "binsz", "coordsys", "proj", "skydir", "axes"), wcs_test_geoms
46
)
47
def test_wcsgeom_get_pix(npix, binsz, coordsys, proj, skydir, axes):
48
    geom = WcsGeom.create(
49
        npix=npix, binsz=binsz, skydir=skydir, proj=proj, coordsys=coordsys, axes=axes
50
    )
51
    pix = geom.get_idx()
52
    if axes is not None:
53
        idx = tuple([1] * len(axes))
54
        pix_img = geom.get_idx(idx=idx)
55
        m = np.all(np.stack([x == y for x, y in zip(idx, pix[2:])]), axis=0)
56
        m2 = pix_img[0] != -1
57
        assert_allclose(pix[0][m], np.ravel(pix_img[0][m2]))
58
        assert_allclose(pix[1][m], np.ravel(pix_img[1][m2]))
59
60
61
@pytest.mark.parametrize(
62
    ("npix", "binsz", "coordsys", "proj", "skydir", "axes"), wcs_test_geoms
63
)
64
def test_wcsgeom_test_pix_to_coord(npix, binsz, coordsys, proj, skydir, axes):
65
    geom = WcsGeom.create(
66
        npix=npix, binsz=binsz, skydir=skydir, proj=proj, coordsys=coordsys, axes=axes
67
    )
68
    assert_allclose(geom.get_coord()[0], geom.pix_to_coord(geom.get_idx())[0])
69
70
71
@pytest.mark.parametrize(
72
    ("npix", "binsz", "coordsys", "proj", "skydir", "axes"), wcs_test_geoms
73
)
74
def test_wcsgeom_test_coord_to_idx(npix, binsz, coordsys, proj, skydir, axes):
75
    geom = WcsGeom.create(
76
        npix=npix, binsz=binsz, proj=proj, coordsys=coordsys, axes=axes
77
    )
78
    assert_allclose(geom.get_idx()[0], geom.coord_to_idx(geom.get_coord())[0])
79
80
    if not geom.is_allsky:
81
        coords = geom.center_coord[:2] + tuple([ax.center[0] for ax in geom.axes])
82
        coords[0][...] += 2.0 * np.max(geom.width[0])
83
        idx = geom.coord_to_idx(coords)
84
        assert_allclose(np.full_like(coords[0].value, -1, dtype=int), idx[0])
85
        idx = geom.coord_to_idx(coords, clip=True)
86
        assert np.all(
87
            np.not_equal(np.full_like(coords[0].value, -1, dtype=int), idx[0])
88
        )
89
90
91
@pytest.mark.parametrize(
92
    ("npix", "binsz", "coordsys", "proj", "skydir", "axes"), wcs_test_geoms
93
)
94
def test_wcsgeom_read_write(tmp_path, npix, binsz, coordsys, proj, skydir, axes):
95
    geom0 = WcsGeom.create(
96
        npix=npix, binsz=binsz, proj=proj, coordsys=coordsys, axes=axes
97
    )
98
99
    hdu_bands = geom0.make_bands_hdu(hdu="BANDS")
100
    hdu_prim = fits.PrimaryHDU()
101
    hdu_prim.header.update(geom0.make_header())
102
103
    hdulist = fits.HDUList([hdu_prim, hdu_bands])
104
    hdulist.writeto(tmp_path / "tmp.fits")
105
106
    with fits.open(tmp_path / "tmp.fits", memmap=False) as hdulist:
107
        geom1 = WcsGeom.from_header(hdulist[0].header, hdulist["BANDS"])
108
109
    assert_allclose(geom0.npix, geom1.npix)
110
    assert geom0.coordsys == geom1.coordsys
111
112
113
def test_wcsgeom_to_hdulist():
114
    npix, binsz, coordsys, proj, skydir, axes = wcs_test_geoms[3]
115
    geom = WcsGeom.create(
116
        npix=npix, binsz=binsz, proj=proj, coordsys=coordsys, axes=axes
117
    )
118
119
    hdu = geom.make_bands_hdu(hdu="TEST")
120
    assert hdu.header["AXCOLS1"] == "E_MIN,E_MAX"
121
    assert hdu.header["AXCOLS2"] == "AXIS1_MIN,AXIS1_MAX"
122
123
124
@pytest.mark.parametrize(
125
    ("npix", "binsz", "coordsys", "proj", "skydir", "axes"), wcs_test_geoms
126
)
127
def test_wcsgeom_contains(npix, binsz, coordsys, proj, skydir, axes):
128
    geom = WcsGeom.create(
129
        npix=npix, binsz=binsz, skydir=skydir, proj=proj, coordsys=coordsys, axes=axes
130
    )
131
    coords = geom.get_coord()
132
    m = np.isfinite(coords[0])
133
    coords = [c[m] for c in coords]
134
    assert_allclose(geom.contains(coords), np.ones(coords[0].shape, dtype=bool))
135
136
    if axes is not None:
137
        coords = [c[0] for c in coords[:2]] + [ax.edges[-1] + 1.0 for ax in axes]
138
        assert_allclose(geom.contains(coords), np.zeros((1,), dtype=bool))
139
140
    if not geom.is_allsky:
141
        coords = [0.0, 0.0] + [ax.center[0] for ax in geom.axes]
142
        assert_allclose(geom.contains(coords), np.zeros((1,), dtype=bool))
143
144
145
def test_wcsgeom_solid_angle():
146
    # Test using a CAR projection map with an extra axis
147
    binsz = 1.0 * u.deg
148
    npix = 10
149
    geom = WcsGeom.create(
150
        skydir=(0, 0),
151
        npix=(npix, npix),
152
        binsz=binsz,
153
        coordsys="GAL",
154
        proj="CAR",
155
        axes=[MapAxis.from_edges([0, 2, 3])],
156
    )
157
158
    solid_angle = geom.solid_angle()
159
160
    # Check array size
161
    assert solid_angle.shape == (2, npix, npix)
162
163
    # Test at b = 0 deg
164
    assert solid_angle.unit == "sr"
165
    assert_allclose(solid_angle.value[0, 5, 5], 0.0003046, rtol=1e-3)
166
167
    # Test at b = 5 deg
168
    assert_allclose(solid_angle.value[0, 9, 5], 0.0003038, rtol=1e-3)
169
170
171
def test_wcsgeom_solid_angle_symmetry():
172
    geom = WcsGeom.create(
173
        skydir=(0, 0), coordsys="GAL", npix=(3, 3), binsz=20.0 * u.deg
174
    )
175
176
    sa = geom.solid_angle()
177
178
    assert_allclose(sa[1, :], sa[1, 0])  # Constant along lon
179
    assert_allclose(sa[0, 1], sa[2, 1])  # Symmetric along lat
180
    with pytest.raises(AssertionError):
181
        # Not constant along lat due to changes in solid angle (great circle)
182
        assert_allclose(sa[:, 1], sa[0, 1])
183
184
185
def test_wcsgeom_solid_angle_ait():
186
    # Pixels that don't correspond to locations on ths sky
187
    # should have solid angles set to NaN
188
    ait_geom = WcsGeom.create(
189
        skydir=(0, 0), width=(360, 180), binsz=20, coordsys="GAL", proj="AIT"
190
    )
191
    solid_angle = ait_geom.solid_angle().to_value("deg2")
192
193
    assert_allclose(solid_angle[4, 1], 397.04838)
194
    assert_allclose(solid_angle[4, 16], 397.751841)
195
    assert_allclose(solid_angle[1, 8], 381.556269)
196
    assert_allclose(solid_angle[7, 8], 398.34725)
197
198
    assert np.isnan(solid_angle[0, 0])
199
200
201
def test_wcsgeom_separation():
202
    geom = WcsGeom.create(
203
        skydir=(0, 0),
204
        npix=10,
205
        binsz=0.1,
206
        coordsys="GAL",
207
        proj="CAR",
208
        axes=[MapAxis.from_edges([0, 2, 3])],
209
    )
210
    position = SkyCoord(1, 0, unit="deg", frame="galactic").icrs
211
    separation = geom.separation(position)
212
213
    assert separation.unit == "deg"
214
    assert separation.shape == (10, 10)
215
    assert_allclose(separation.value[0, 0], 0.7106291438079875)
216
217
    # Make sure it also works for 2D maps as input
218
    separation = geom.to_image().separation(position)
219
    assert separation.unit == "deg"
220
    assert separation.shape == (10, 10)
221
    assert_allclose(separation.value[0, 0], 0.7106291438079875)
222
223
224
def test_cutout():
225
    geom = WcsGeom.create(
226
        skydir=(0, 0),
227
        npix=10,
228
        binsz=0.1,
229
        coordsys="GAL",
230
        proj="CAR",
231
        axes=[MapAxis.from_edges([0, 2, 3])],
232
    )
233
    position = SkyCoord(0.1, 0.2, unit="deg", frame="galactic")
234
    cutout_geom = geom.cutout(position=position, width=2 * 0.3 * u.deg, mode="trim")
235
236
    center_coord = cutout_geom.center_coord
237
    assert_allclose(center_coord[0].value, 0.1)
238
    assert_allclose(center_coord[1].value, 0.2)
239
    assert_allclose(center_coord[2].value, 2.0)
240
241
    assert cutout_geom.data_shape == (2, 6, 6)
242
243
244
def test_cutout_info():
245
    geom = WcsGeom.create(
246
        skydir=(0, 0),
247
        npix=10,
248
    )
249
    position = SkyCoord(0, 0, unit="deg")
250
    cutout_geom = geom.cutout(position=position, width="2 deg")
251
    assert cutout_geom.cutout_info["parent-slices"][0].start == 3
252
    assert cutout_geom.cutout_info["parent-slices"][1].start == 3
253
254
    assert cutout_geom.cutout_info["cutout-slices"][0].start == 0
255
    assert cutout_geom.cutout_info["cutout-slices"][1].start == 0
256
257
    header = cutout_geom.make_header()
258
    assert "PSLICE1" in header
259
    assert "PSLICE2" in header
260
    assert "CSLICE1" in header
261
    assert "CSLICE2" in header
262
263
    geom = WcsGeom.from_header(header)
264
    assert geom.cutout_info["parent-slices"][0].start == 3
265
    assert geom.cutout_info["parent-slices"][1].start == 3
266
267
    assert geom.cutout_info["cutout-slices"][0].start == 0
268
    assert geom.cutout_info["cutout-slices"][1].start == 0
269
270
271
def test_wcsgeom_get_coord():
272
    geom = WcsGeom.create(
273
        skydir=(0, 0), npix=(4, 3), binsz=1, coordsys="GAL", proj="CAR"
274
    )
275
    coord = geom.get_coord(mode="edges")
276
    assert_allclose(coord.lon[0, 0].value, 2)
277
    assert coord.lon[0, 0].unit == "deg"
278
    assert_allclose(coord.lat[0, 0].value, -1.5)
279
    assert coord.lat[0, 0].unit == "deg"
280
281
282
def test_wcsgeom_instance_cache():
283
    geom_1 = WcsGeom.create(npix=(3, 3))
284
    geom_2 = WcsGeom.create(npix=(3, 3))
285
286
    coord_1, coord_2 = geom_1.get_coord(), geom_2.get_coord()
287
288
    assert geom_1.get_coord.cache_info().misses == 1
289
    assert geom_2.get_coord.cache_info().misses == 1
290
291
    coord_1_cached, coord_2_cached = geom_1.get_coord(), geom_2.get_coord()
292
293
    assert geom_1.get_coord.cache_info().hits == 1
294
    assert geom_2.get_coord.cache_info().hits == 1
295
296
    assert geom_1.get_coord.cache_info().currsize == 1
297
    assert geom_2.get_coord.cache_info().currsize == 1
298
299
    assert id(coord_1) == id(coord_1_cached)
300
    assert id(coord_2) == id(coord_2_cached)
301
302
303
def test_wcsgeom_get_pix_coords():
304
    geom = WcsGeom.create(
305
        skydir=(0, 0), npix=(4, 3), binsz=1, coordsys="GAL", proj="CAR", axes=axes1
306
    )
307
    idx_center = geom.get_pix(mode="center")
308
309
    for idx in idx_center:
310
        assert idx.shape == (2, 3, 4)
311
        assert_allclose(idx[0, 0, 0], 0)
312
313
    idx_edge = geom.get_pix(mode="edges")
314
    for idx, desired in zip(idx_edge, [-0.5, -0.5, 0]):
315
        assert idx.shape == (2, 4, 5)
316
        assert_allclose(idx[0, 0, 0], desired)
317
318
319
def test_geom_repr():
320
    geom = WcsGeom.create(
321
        skydir=(0, 0), npix=(10, 4), binsz=50, coordsys="GAL", proj="AIT"
322
    )
323
    assert geom.__class__.__name__ in repr(geom)
324
325
326
def test_geom_refpix():
327
    refpix = (400, 300)
328
    geom = WcsGeom.create(
329
        skydir=(0, 0), npix=(800, 600), refpix=refpix, binsz=0.1, coordsys="GAL"
330
    )
331
    assert_allclose(geom.wcs.wcs.crpix, refpix)
332
333
334
def test_region_mask():
335
    from regions import CircleSkyRegion
336
337
    geom = WcsGeom.create(npix=(3, 3), binsz=2, proj="CAR")
338
339
    r1 = CircleSkyRegion(SkyCoord(0, 0, unit="deg"), 1 * u.deg)
340
    r2 = CircleSkyRegion(SkyCoord(20, 20, unit="deg"), 1 * u.deg)
341
    regions = [r1, r2]
342
343
    mask = geom.region_mask(regions)  # default inside=True
344
    assert mask.dtype == bool
345
    assert np.sum(mask) == 1
346
347
    mask = geom.region_mask(regions, inside=False)
348
    assert np.sum(mask) == 8
349
350
351
def test_energy_mask():
352
    energy_axis = MapAxis.from_nodes(
353
        [1, 10, 100], interp="log", name="energy", unit="TeV"
354
    )
355
    geom = WcsGeom.create(npix=(1, 1), binsz=1, proj="CAR", axes=[energy_axis])
356
357
    mask = geom.energy_mask(emin=3 * u.TeV)
358
    assert not mask[0, 0, 0]
359
    assert mask[1, 0, 0]
360
    assert mask[2, 0, 0]
361
362
    mask = geom.energy_mask(emax=30 * u.TeV)
363
    assert mask[0, 0, 0]
364
    assert mask[1, 0, 0]
365
    assert not mask[2, 0, 0]
366
367
    mask = geom.energy_mask(emin=3 * u.TeV, emax=30 * u.TeV)
368
    assert not mask[0, 0, 0]
369
    assert not mask[-1, 0, 0]
370
    assert mask[1, 0, 0]
371
372
373
@pytest.mark.parametrize(
374
    ("width", "out"),
375
    [
376
        (10, (10, 10)),
377
        ((10 * u.deg).to("rad"), (10, 10)),
378
        ((10, 5), (10, 5)),
379
        (("10 deg", "5 deg"), (10, 5)),
380
        (Angle([10, 5], "deg"), (10, 5)),
381
        ((10 * u.deg, 5 * u.deg), (10, 5)),
382
        ((10, 5) * u.deg, (10, 5)),
383
        ([10, 5], (10, 5)),
384
        (["10 deg", "5 deg"], (10, 5)),
385
        (np.array([10, 5]), (10, 5)),
386
    ],
387
)
388
def test_check_width(width, out):
389
    width = _check_width(width)
390
    assert isinstance(width, tuple)
391
    assert isinstance(width[0], float)
392
    assert isinstance(width[1], float)
393
    assert width == out
394
395
    geom = WcsGeom.create(width=width, binsz=1.0)
396
    assert tuple(geom.npix) == out
397
398
399
def test_check_width_bad_input():
400
    with pytest.raises(IndexError):
401
        _check_width(width=(10,))
402
403
404
def test_get_axis_index_by_name():
405
    e_axis = MapAxis.from_edges([1, 5], name="energy")
406
    geom = WcsGeom.create(width=5, binsz=1.0, axes=[e_axis])
407
    assert geom.get_axis_index_by_name("Energy") == 0
408
    with pytest.raises(ValueError):
409
        geom.get_axis_index_by_name("time")
410
411
412
test_axis1 = [MapAxis(nodes=(1, 2, 3, 4), unit="TeV", node_type="center")]
413
test_axis2 = [
414
    MapAxis(nodes=(1, 2, 3, 4), unit="TeV", node_type="center"),
415
    MapAxis(nodes=(1, 2, 3), unit="TeV", node_type="center"),
416
]
417
418
skydir2 = SkyCoord(110.0, 75.0 + 1e-8, unit="deg", frame="icrs")
419
skydir3 = SkyCoord(110.0, 75.0 + 1e-3, unit="deg", frame="icrs")
420
421
compatibility_test_geoms = [
422
    (10, 0.1, "GAL", "CAR", skydir, test_axis1, True),
423
    (10, 0.1, "GAL", "CAR", skydir2, test_axis1, True),
424
    (10, 0.1, "GAL", "CAR", skydir3, test_axis1, False),
425
    (10, 0.1, "GAL", "TAN", skydir, test_axis1, False),
426
    (8, 0.1, "GAL", "CAR", skydir, test_axis1, False),
427
    (10, 0.1, "GAL", "CAR", skydir, test_axis2, False),
428
    (10, 0.1, "GAL", "CAR", skydir.galactic, test_axis1, True),
429
]
430
431
432
@pytest.mark.parametrize(
433
    ("npix", "binsz", "coordsys", "proj", "skypos", "axes", "result"),
434
    compatibility_test_geoms,
435
)
436
def test_wcs_geom_equal(npix, binsz, coordsys, proj, skypos, axes, result):
437
    geom0 = WcsGeom.create(
438
        skydir=skydir, npix=10, binsz=0.1, proj="CAR", coordsys="GAL", axes=test_axis1
439
    )
440
    geom1 = WcsGeom.create(
441
        skydir=skypos, npix=npix, binsz=binsz, proj=proj, coordsys=coordsys, axes=axes
442
    )
443
444
    assert (geom0 == geom1) is result
445
    assert (geom0 != geom1) is not result
446
447
448
@pytest.mark.parametrize("node_type", ["edges", "center"])
449
@pytest.mark.parametrize("interp", ["log", "lin", "sqrt"])
450
def test_read_write(tmp_path, node_type, interp):
451
    # Regression test for MapAxis interp and node_type FITS serialization
452
    # https://github.com/gammapy/gammapy/issues/1887
453
    e_ax = MapAxis([1, 2], interp, "energy", node_type, "TeV")
454
    t_ax = MapAxis([3, 4], interp, "time", node_type, "s")
455
    m = Map.create(binsz=1, npix=10, axes=[e_ax, t_ax], unit="m2")
456
457
    # Check what Gammapy writes in the FITS header
458
    header = m.make_hdu().header
459
    assert header["INTERP1"] == interp
460
    assert header["INTERP2"] == interp
461
462
    # Check that all MapAxis properties are preserved on FITS I/O
463
    m.write(tmp_path / "tmp.fits", overwrite=True)
464
    m2 = Map.read(tmp_path / "tmp.fits")
465
    assert m2.geom == m.geom
466
467
468
@pytest.mark.parametrize(
469
    ("npix", "binsz", "coordsys", "proj", "skypos", "axes", "result"),
470
    compatibility_test_geoms,
471
)
472
def test_wcs_geom_to_binsz(npix, binsz, coordsys, proj, skypos, axes, result):
473
    geom = WcsGeom.create(
474
        skydir=skydir, npix=10, binsz=0.1, proj="CAR", coordsys="GAL", axes=test_axis1
475
    )
476
477
    geom_new = geom.to_binsz(binsz=0.5)
478
479
    assert_allclose(geom_new.pixel_scales.value, 0.5)
480