gammapy.maps.tests.test_axes   F
last analyzed

Complexity

Total Complexity 71

Size/Duplication

Total Lines 760
Duplicated Lines 4.87 %

Importance

Changes 0
Metric Value
eloc 523
dl 37
loc 760
rs 2.7199
c 0
b 0
f 0
wmc 71

52 Functions

Rating   Name   Duplication   Size   Complexity  
A test_mapaxis_repr() 0 3 1
A time_interval() 0 6 1
A energy_axis_ref() 0 4 1
A time_intervals() 0 6 1
A test_group_table_outside_range() 0 5 2
A test_downsample_non_regular_nodes() 0 8 1
A test_mapaxis_coord_to_idx() 0 4 1
A test_map_axis_from_energy_units() 0 6 3
A test_downsample() 0 14 1
A test_map_axis_format_plot_xaxis() 0 12 3
A test_rename() 0 11 1
A test_mapaxis_invalid_name() 0 3 2
A test_map_axis_aligned() 0 4 1
A test_mapaxis_pix_to_coord() 0 6 1
A test_downsample_non_regular() 0 8 1
A test_group_tablenergy_edges() 0 15 1
A test_from_table_time_axis() 0 15 1
A test_label_map_axis_coord_to_idx() 0 18 2
A test_axes_basics() 0 27 1
A test_map_with_time_axis() 0 10 1
A test_map_axis_pad() 0 11 1
A test_up_downsample_consistency() 0 5 1
A test_mapaxis_equal() 0 17 1
A test_upsample_non_regular_nodes() 0 8 1
A test_upsample() 0 10 1
A test_mapaxis_from_bounds() 0 8 2
A test_time_map_axis_format_plot_xaxis() 0 15 3
A test_one_bin_nodes() 0 7 1
A test_mixed_axes() 0 31 1
A test_group_table_above_range() 12 12 1
A test_bad_length_sort_time_axis() 0 11 3
A test_group_table_below_range() 12 12 1
A test_coord_to_idx_time_axis() 0 23 1
A test_time_map_axis_from_time_bounds() 0 6 1
A test_squash() 0 10 1
A test_from_time_edges_time_axis() 0 16 1
A test_map_axis_plot_helpers() 0 8 1
A test_slice_squash_time_axis() 0 13 1
A test_single_interval_time_axis() 0 21 1
A test_slice_time_axis() 0 12 1
A test_time_axis() 0 26 2
A test_map_axes_pad() 0 9 1
A test_mapaxis_init_from_edges() 0 9 2
A test_label_map_axis_basics() 0 23 3
A test_upsample_non_regular() 0 8 1
A test_mapaxis_slice() 0 21 1
A test_group_table_basic() 13 13 1
A test_axes_getitem() 0 14 1
A test_incorrect_time_axis() 0 11 3
A test_time_axis_plot_helpers() 0 17 1
A test_mapaxis_from_nodes() 0 9 2
A test_from_gti_time_axis() 0 10 1

How to fix   Duplicated Code    Complexity   

Duplicated Code

Duplicate code is one of the most pungent code smells. A rule that is often used is to re-structure code once it is duplicated in three or more places.

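A common duplication problem in test modules is several test bodies that repeat the same assertion sequence and differ only in inputs and expected values. The usual solution is to extract a shared assertion helper or to parametrize the test.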

Complexity

 Tip:   Before tackling complexity, make sure that you eliminate any duplication first. This often can reduce the size of classes significantly.

Complex modules like gammapy.maps.tests.test_axes often do many different things. To break such a module down, identify a cohesive component within it. A common approach to finding such a component is to look for functions or fields that share the same prefixes or suffixes.

Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a sub-class, Extract Subclass is also a candidate, and is often faster.

1
# Licensed under a 3-clause BSD style license - see LICENSE.rst
2
import pytest
3
import numpy as np
4
from numpy.testing import assert_allclose, assert_equal
5
import astropy.units as u
6
from astropy.table import Table
7
from astropy.time import Time
8
from astropy.visualization import quantity_support
9
import matplotlib.pyplot as plt
10
from gammapy.data import GTI
11
from gammapy.maps import LabelMapAxis, MapAxes, MapAxis, RegionNDMap, TimeMapAxis
12
from gammapy.utils.scripts import make_path
13
from gammapy.utils.testing import assert_time_allclose, mpl_plot_check, requires_data
14
from gammapy.utils.time import time_ref_to_dict
15
16
# Node values shared by the parametrized MapAxis test cases below.
_NODE_VALUES = (0.25, 0.75, 1.0, 2.0)
_INTERPS = ("lin", "log", "sqrt")

# (nodes, interp): one fresh array per case so cases never share a mutable input.
MAP_AXIS_INTERP = [(np.array(_NODE_VALUES), interp) for interp in _INTERPS]

# (nodes, interp, node_type): every interp with "edges", then every interp
# with "center".
MAP_AXIS_NODE_TYPES = [
    (list(_NODE_VALUES), interp, node_type)
    for node_type in ("edges", "center")
    for interp in _INTERPS
]

nodes_array = np.array(_NODE_VALUES)

# (nodes, interp, node_type, unit, name, expected result of comparing with the
# reference axis built in test_mapaxis_equal)
MAP_AXIS_NODE_TYPE_UNIT = [
    (nodes_array, "lin", "edges", "s", "TEST", True),
    (nodes_array, "log", "edges", "s", "test", False),
    (nodes_array, "lin", "edges", "TeV", "TEST", False),
    (nodes_array, "sqrt", "edges", "s", "test", False),
    (nodes_array, "lin", "center", "s", "test", False),
    (nodes_array + 1e-9, "lin", "edges", "s", "test", True),
    (nodes_array + 1e-3, "lin", "edges", "s", "test", False),
    (nodes_array / 3600.0, "lin", "edges", "hr", "TEST", True),
]
44
45
46
@pytest.fixture
def time_intervals():
    """Twenty one-hour intervals spread over ten days after 2020-03-19.

    Returns a dict with relative ``t_min``/``t_max`` quantities and the
    reference time ``t_ref``.
    """
    t0 = Time("2020-03-19")
    t_min = np.linspace(0, 10, 20) * u.d
    t_max = t_min + 1 * u.h
    return {"t_min": t_min, "t_max": t_max, "t_ref": t0}
52
53
54
@pytest.fixture
def time_interval():
    """A single interval from 1 d to 11 d after the 2020-03-19 reference.

    Returns a dict with relative ``t_min``/``t_max`` quantities and ``t_ref``.
    """
    t0 = Time("2020-03-19")
    t_min = 1 * u.d
    t_max = 11 * u.d
    return {"t_min": t_min, "t_max": t_max, "t_ref": t0}
60
61
62
@pytest.fixture(scope="session")
def energy_axis_ref():
    """Session-wide energy axis with edges 1, 2, ..., 10 TeV (9 bins)."""
    edges = np.arange(1, 11) * u.TeV
    return MapAxis.from_edges(edges, name="energy")
66
67
68
def test_mapaxis_repr():
    """The repr of a MapAxis names its class."""
    axis = MapAxis([1, 2, 3], name="test")
    assert "MapAxis" in repr(axis)
71
72
73
def test_mapaxis_invalid_name():
    """A non-string axis name is rejected with TypeError."""
    with pytest.raises(TypeError):
        MapAxis([1, 2, 3], name=1)
76
77
78
@pytest.mark.parametrize(
    ("nodes", "interp", "node_type", "unit", "name", "result"),
    MAP_AXIS_NODE_TYPE_UNIT,
)
def test_mapaxis_equal(nodes, interp, node_type, unit, name, result):
    """Compare a reference "lin"/"edges" axis in seconds with each case.

    ``result`` is the expected outcome of ``==``; ``!=`` must always give the
    opposite answer.
    """
    axis1 = MapAxis(
        nodes=[0.25, 0.75, 1.0, 2.0],
        name="test",
        unit="s",
        interp="lin",
        node_type="edges",
    )

    axis2 = MapAxis(nodes, name=name, unit=unit, interp=interp, node_type=node_type)

    assert (axis1 == axis2) is result
    assert (axis1 != axis2) is not result
95
96
97
def test_squash():
    """Squashing a 3-bin linear axis leaves one bin spanning the full range."""
    axis = MapAxis(
        nodes=[0, 1, 2, 3], unit="TeV", name="energy", node_type="edges", interp="lin"
    )
    ax_sq = axis.squash()

    assert_allclose(ax_sq.nbin, 1)
    assert_allclose(axis.edges[0], ax_sq.edges[0])
    assert_allclose(axis.edges[-1], ax_sq.edges[1])
    # Linear interpolation: the centre is the midpoint of [0, 3] TeV.
    assert_allclose(ax_sq.center, 1.5 * u.TeV)
107
108
109
def test_upsample():
    """Upsampling by 10 multiplies nbin by 10 and keeps the outer edges and node type."""
    axis = MapAxis(
        nodes=[0, 1, 2, 3], unit="TeV", name="energy", node_type="edges", interp="lin"
    )
    axis_up = axis.upsample(10)

    assert_allclose(axis_up.nbin, 10 * axis.nbin)
    assert_allclose(axis_up.edges[0], axis.edges[0])
    assert_allclose(axis_up.edges[-1], axis.edges[-1])
    assert axis_up.node_type == axis.node_type
119
120
121
def test_downsample():
    """Downsampling an 8-bin axis by 2 halves nbin and keeps the outer edges and node type."""
    axis = MapAxis(
        nodes=[0, 1, 2, 3, 4, 5, 6, 7, 8],
        unit="TeV",
        name="energy",
        node_type="edges",
        interp="lin",
    )
    axis_down = axis.downsample(2)

    assert_allclose(axis_down.nbin, 0.5 * axis.nbin)
    assert_allclose(axis_down.edges[0], axis.edges[0])
    assert_allclose(axis_down.edges[-1], axis.edges[-1])
    assert axis_down.node_type == axis.node_type
135
136
137
def test_upsample_non_regular():
    """Upsampling irregular edges by 2 doubles nbin and keeps the outer edges."""
    axis = MapAxis.from_edges([0, 1, 3, 7], name="test", interp="lin")
    axis_up = axis.upsample(2)

    assert_allclose(axis_up.nbin, 2 * axis.nbin)
    assert_allclose(axis_up.edges[0], axis.edges[0])
    assert_allclose(axis_up.edges[-1], axis.edges[-1])
    assert axis_up.node_type == axis.node_type
145
146
147
def test_upsample_non_regular_nodes():
    """Upsampling irregular nodes by 2 gives 2 * nbin - 1 nodes with the same end nodes."""
    axis = MapAxis.from_nodes([0, 1, 3, 7], name="test", interp="lin")
    axis_up = axis.upsample(2)

    assert_allclose(axis_up.nbin, 2 * axis.nbin - 1)
    assert_allclose(axis_up.center[0], axis.center[0])
    assert_allclose(axis_up.center[-1], axis.center[-1])
    assert axis_up.node_type == axis.node_type
155
156
157
def test_downsample_non_regular():
    """Downsampling irregular edges by 2 halves nbin and keeps the outer edges."""
    axis = MapAxis.from_edges([0, 1, 3, 7, 13], name="test", interp="lin")
    axis_down = axis.downsample(2)

    assert_allclose(axis_down.nbin, 0.5 * axis.nbin)
    assert_allclose(axis_down.edges[0], axis.edges[0])
    assert_allclose(axis_down.edges[-1], axis.edges[-1])
    assert axis_down.node_type == axis.node_type
165
166
167
def test_downsample_non_regular_nodes():
    """Downsample a second irregular edge axis by 2.

    NOTE(review): despite the name, this builds the axis with ``from_edges``;
    a ``from_nodes`` variant mirroring test_upsample_non_regular_nodes may
    have been intended. Confirm before changing, because the nbin
    expectation would differ.
    """
    axis = MapAxis.from_edges([0, 1, 3, 7, 9], name="test", interp="lin")
    axis_down = axis.downsample(2)

    assert_allclose(axis_down.nbin, 0.5 * axis.nbin)
    assert_allclose(axis_down.edges[0], axis.edges[0])
    assert_allclose(axis_down.edges[-1], axis.edges[-1])
    assert axis_down.node_type == axis.node_type
175
176
177
@pytest.mark.parametrize("factor", [1, 3, 5, 7, 11])
def test_up_downsample_consistency(factor):
    """Upsampling then downsampling by the same factor restores the edges."""
    axis = MapAxis.from_edges([0, 1, 3, 7, 13], name="test", interp="lin")
    axis_new = axis.upsample(factor).downsample(factor)
    assert_allclose(axis.edges, axis_new.edges)
182
183
184
def test_one_bin_nodes():
    """A single-node axis maps both coordinates tested here to pixel 0, and back."""
    axis = MapAxis.from_nodes([1], name="test", unit="deg")

    assert_allclose(axis.center, 1 * u.deg)
    assert_allclose(axis.coord_to_pix(1 * u.deg), 0)
    assert_allclose(axis.coord_to_pix(2 * u.deg), 0)
    assert_allclose(axis.pix_to_coord(0), 1 * u.deg)
191
192
def _assert_group_table(
    groups, group_idx, idx_min, idx_max, energy_min, energy_max, bin_type
):
    """Compare every column of a ``MapAxis.group_table`` result with expectations.

    Energies are compared in TeV, so the expectations do not depend on the unit
    of the requested group edges.
    """
    assert_allclose(groups["group_idx"], group_idx)
    assert_allclose(groups["idx_min"], idx_min)
    assert_allclose(groups["idx_max"], idx_max)
    assert_allclose(groups["energy_min"].quantity.to_value("TeV"), energy_min)
    assert_allclose(groups["energy_max"].quantity.to_value("TeV"), energy_max)
    # bin_type entries are stripped before comparison (presumably padded strings)
    assert_equal([_.strip() for _ in groups["bin_type"]], bin_type)


def test_group_table_basic(energy_axis_ref):
    """Group edges aligned with the axis range produce only "normal" groups."""
    energy_edges = [1, 2, 10] * u.TeV
    groups = energy_axis_ref.group_table(energy_edges)

    _assert_group_table(
        groups,
        group_idx=[0, 1],
        idx_min=[0, 1],
        idx_max=[0, 8],
        energy_min=[1, 2],
        energy_max=[2, 10],
        bin_type=["normal", "normal"],
    )


@pytest.mark.parametrize(
    "energy_edges",
    [[1.8, 4.8, 7.2] * u.TeV, [2, 5, 7] * u.TeV, [2000, 5000, 7000] * u.GeV],
)
def test_group_tablenergy_edges(energy_axis_ref, energy_edges):
    """Interior group edges add underflow/overflow groups, regardless of unit."""
    groups = energy_axis_ref.group_table(energy_edges)

    _assert_group_table(
        groups,
        group_idx=[0, 1, 2, 3],
        idx_min=[0, 1, 4, 6],
        idx_max=[0, 3, 5, 8],
        energy_min=[1, 2, 5, 7],
        energy_max=[2, 5, 7, 10],
        bin_type=["underflow", "normal", "normal", "overflow"],
    )


def test_group_table_below_range(energy_axis_ref):
    """Group edges starting below the axis range are clipped to the first bin."""
    energy_edges = [0.7, 0.8, 1, 4] * u.TeV
    groups = energy_axis_ref.group_table(energy_edges)

    _assert_group_table(
        groups,
        group_idx=[0, 1],
        idx_min=[0, 3],
        idx_max=[2, 8],
        energy_min=[1, 4],
        energy_max=[4, 10],
        bin_type=["normal", "overflow"],
    )


def test_group_table_above_range(energy_axis_ref):
    """Group edges ending above the axis range are clipped to the last bin."""
    energy_edges = [5, 7, 11, 13] * u.TeV
    groups = energy_axis_ref.group_table(energy_edges)

    _assert_group_table(
        groups,
        group_idx=[0, 1, 2],
        idx_min=[0, 4, 6],
        idx_max=[3, 5, 8],
        energy_min=[1, 5, 7],
        energy_max=[5, 7, 10],
        bin_type=["underflow", "normal", "normal"],
    )
251
252
253
def test_group_table_outside_range(energy_axis_ref):
    """Group edges entirely above the axis range raise ValueError."""
    energy_edges = [20, 30, 40] * u.TeV

    with pytest.raises(ValueError):
        energy_axis_ref.group_table(energy_edges)
258
259
260
def test_map_axis_aligned():
    """Axes with different nodes, interpolation and node type are not aligned."""
    ax1 = MapAxis([1, 2, 3], interp="lin", node_type="edges")
    ax2 = MapAxis([1.5, 2.5], interp="log", node_type="center")
    assert not ax1.is_aligned(ax2)
264
265
266
def test_map_axis_pad():
    """Padding a one-decade log axis adds decade-wide bins on the requested sides."""
    axis = MapAxis.from_energy_bounds("1 TeV", "10 TeV", nbin=1)

    padded = axis.pad(pad_width=(0, 1))
    assert_allclose(padded.edges, [1, 10, 100] * u.TeV)

    padded = axis.pad(pad_width=(1, 0))
    assert_allclose(padded.edges, [0.1, 1, 10] * u.TeV)

    # A scalar pad_width pads both sides.
    padded = axis.pad(pad_width=1)
    assert_allclose(padded.edges, [0.1, 1, 10, 100] * u.TeV)
277
278
279
def test_map_axes_pad():
    """MapAxes.pad pads only the named axis."""
    axis_1 = MapAxis.from_energy_bounds("1 TeV", "10 TeV", nbin=1)
    axis_2 = MapAxis.from_bounds(0, 1, nbin=2, unit="deg", name="rad")

    axes = MapAxes([axis_1, axis_2])
    axes = axes.pad(axis_name="energy", pad_width=1)

    assert_allclose(axes["energy"].edges, [0.1, 1, 10, 100] * u.TeV)
288
289
290
def test_rename():
    """rename returns a renamed copy; MapAxes.rename_axes renames several axes."""
    axis_1 = MapAxis.from_energy_bounds("1 TeV", "10 TeV", nbin=1)
    axis = axis_1.rename("energy_true")
    # The source axis keeps its name.
    assert axis_1.name == "energy"
    assert axis.name == "energy_true"

    axis_2 = MapAxis.from_bounds(0, 1, nbin=2, unit="deg", name="rad")

    axes = MapAxes([axis_1, axis_2])
    axes = axes.rename_axes(["energy", "rad"], ["energy_true", "angle"])
    assert axes.names == ["energy_true", "angle"]
301
302
303
@pytest.mark.parametrize(("edges", "interp"), MAP_AXIS_INTERP)
def test_mapaxis_init_from_edges(edges, interp):
    """Build an axis from edges, and reject too-short, duplicate or unsorted edges."""
    axis = MapAxis(edges, interp=interp)
    assert_allclose(axis.edges, edges)
    assert_allclose(axis.nbin, len(edges) - 1)

    # Each invalid input gets its own pytest.raises. Inside a single block,
    # only the first raising call would run and the rest would be dead code.
    for bad_edges in ([1], [0, 1, 1, 2], [0, 1, 3, 2]):
        with pytest.raises(ValueError):
            MapAxis.from_edges(bad_edges)
312
313
314
@pytest.mark.parametrize(("nodes", "interp"), MAP_AXIS_INTERP)
def test_mapaxis_from_nodes(nodes, interp):
    """Build an axis from nodes, and reject empty, duplicate or unsorted nodes."""
    axis = MapAxis.from_nodes(nodes, interp=interp)
    assert_allclose(axis.center, nodes)
    assert_allclose(axis.nbin, len(nodes))

    # Each invalid input gets its own pytest.raises. Inside a single block,
    # only the first raising call would run and the rest would be dead code.
    for bad_nodes in ([], [0, 1, 1, 2], [0, 1, 3, 2]):
        with pytest.raises(ValueError):
            MapAxis.from_nodes(bad_nodes)
323
324
325
@pytest.mark.parametrize(("nodes", "interp"), MAP_AXIS_INTERP)
def test_mapaxis_from_bounds(nodes, interp):
    """from_bounds spans the given bounds with nbin bins; empty bounds raise."""
    axis = MapAxis.from_bounds(nodes[0], nodes[-1], 3, interp=interp)
    assert_allclose(axis.edges[0], nodes[0])
    assert_allclose(axis.edges[-1], nodes[-1])
    assert_allclose(axis.nbin, 3)

    with pytest.raises(ValueError):
        MapAxis.from_bounds(1, 1, 1)
333
334
335
def test_map_axis_from_energy_units():
    """Energy axis constructors reject non-energy units."""
    with pytest.raises(ValueError):
        MapAxis.from_energy_bounds(0.1, 10, 2, unit="deg")

    with pytest.raises(ValueError):
        MapAxis.from_energy_edges([0.1, 1, 10] * u.deg)
341
342
343
@pytest.mark.parametrize(("nodes", "interp", "node_type"), MAP_AXIS_NODE_TYPES)
def test_mapaxis_pix_to_coord(nodes, interp, node_type):
    """Bin centres sit at integer pixels and edges at half-integer pixels."""
    axis = MapAxis(nodes, interp=interp, node_type=node_type)

    center_pix = np.arange(axis.nbin, dtype=float)
    assert_allclose(axis.center, axis.pix_to_coord(center_pix))

    edge_pix = np.arange(axis.nbin + 1, dtype=float) - 0.5
    assert_allclose(edge_pix, axis.coord_to_pix(axis.edges))
350
351
352
@pytest.mark.parametrize(("nodes", "interp", "node_type"), MAP_AXIS_NODE_TYPES)
def test_mapaxis_coord_to_idx(nodes, interp, node_type):
    """Each bin centre maps to its own bin index."""
    axis = MapAxis(nodes, interp=interp, node_type=node_type)
    assert_allclose(np.arange(axis.nbin, dtype=int), axis.coord_to_idx(axis.center))
356
357
358
@pytest.mark.parametrize(("nodes", "interp", "node_type"), MAP_AXIS_NODE_TYPES)
def test_mapaxis_slice(nodes, interp, node_type):
    """Slicing keeps the selected bins' centres for bounded and open-ended slices."""
    axis = MapAxis(nodes, interp=interp, node_type=node_type)

    # (slice, expected number of bins in the sliced axis)
    cases = [
        (slice(1, 3), 2),
        (slice(1, None), axis.nbin - 1),
        (slice(None, 2), 2),
        (slice(None, -1), axis.nbin - 1),
    ]
    for slc, expected_nbin in cases:
        saxis = axis.slice(slc)
        assert_allclose(saxis.nbin, expected_nbin)
        assert_allclose(saxis.center, axis.center[slc])
379
380
381
def test_map_axis_plot_helpers():
    """Plot helpers of a node axis: formatted labels, centres and edges."""
    axis = MapAxis.from_nodes([0, 1, 2], unit="deg", name="offset")
    labels = axis.as_plot_labels

    assert labels[0] == "0.00e+00 deg"

    assert_allclose(axis.center, axis.as_plot_center)
    assert_allclose(axis.edges, axis.as_plot_edges)
389
390
391
def test_time_axis(time_intervals):
    """Basic properties, string form, copy equality and contiguity of a TimeMapAxis."""
    axis = TimeMapAxis(
        time_intervals["t_min"], time_intervals["t_max"], time_intervals["t_ref"]
    )
    axis_copy = axis.copy()

    assert axis.nbin == 20
    assert axis.name == "time"
    assert axis.node_type == "intervals"

    assert_allclose(axis.time_delta.to_value("min"), 60)
    assert_allclose(axis.time_mid[0].mjd, 58927.020833333336)

    axis_str = str(axis)
    assert "time" in axis_str
    assert "20" in axis_str

    with pytest.raises(ValueError):
        axis.assert_name("bad")

    assert axis_copy == axis

    # The 1 h intervals leave gaps. Presumably one extra bin is added per
    # gap, giving 20 + 19 bins.
    assert not axis.is_contiguous

    ax_cont = axis.to_contiguous()
    assert_allclose(ax_cont.nbin, 39)
417
418
419
def test_single_interval_time_axis(time_interval):
    """Pixel coordinates on a one-bin TimeMapAxis; times past the end give NaN."""
    axis = TimeMapAxis(
        edges_min=time_interval["t_min"],
        edges_max=time_interval["t_max"],
        reference_time=time_interval["t_ref"],
    )

    coord = Time(58933, format="mjd") + u.Quantity([1.5, 3.5, 10], unit="d")
    pix = axis.coord_to_pix(coord)

    assert axis.nbin == 1
    assert_allclose(axis.time_delta.to_value("d"), 10)
    assert_allclose(axis.time_mid[0].mjd, 58933)

    # Just inside the interval bounds map to the bin's pixel edges.
    pix_min = axis.coord_to_pix(time_interval["t_min"] + 0.001 * u.s)
    assert_allclose(pix_min, -0.5)

    pix_max = axis.coord_to_pix(time_interval["t_max"] - 0.001 * u.s)
    assert_allclose(pix_max, 0.5)

    assert_allclose(pix, [0.15, 0.35, np.nan])
440
441
442
def test_slice_squash_time_axis(time_intervals):
    """Squash covers the full time range; a slice keeps the selected bins."""
    axis = TimeMapAxis(
        time_intervals["t_min"], time_intervals["t_max"], time_intervals["t_ref"]
    )
    axis_squash = axis.squash()
    axis_slice = axis.slice(slice(1, 5))

    assert axis_squash.nbin == 1
    assert_allclose(axis_squash.time_min[0].mjd, 58927)
    assert_allclose(axis_squash.time_delta.to_value("d"), 10.04166666)

    assert axis_slice.nbin == 4
    assert_allclose(axis_slice.time_delta.to_value("d")[0], 0.04166666666)

    assert axis_squash != axis_slice
455
456
457
def test_from_time_edges_time_axis():
    """from_time_edges infers the reference time; the unit choice does not affect equality."""
    t0 = Time("2020-03-19")
    t_min = t0 + np.linspace(0, 10, 20) * u.d
    t_max = t_min + 1 * u.h

    axis = TimeMapAxis.from_time_edges(t_min, t_max)
    axis_h = TimeMapAxis.from_time_edges(t_min, t_max, unit="h")

    assert axis.nbin == 20
    assert axis.name == "time"
    assert_time_allclose(axis.reference_time, t0)
    assert_allclose(axis.time_delta.to_value("min"), 60)
    assert_allclose(axis.time_mid[0].mjd, 58927.020833333336)

    assert_allclose(axis_h.time_delta.to_value("h"), 1)
    assert_allclose(axis_h.time_mid[0].mjd, 58927.020833333336)

    assert axis == axis_h
473
474
475
def test_incorrect_time_axis():
    """TimeMapAxis rejects a non-Time reference and overlapping intervals."""
    tmin = np.linspace(0, 10, 20) * u.h
    tmax = np.linspace(1, 11, 20) * u.h

    # incorrect reference time (a Quantity instead of a Time)
    with pytest.raises(ValueError):
        TimeMapAxis(tmin, tmax, reference_time=51000 * u.d, name="time")

    # overlapping time intervals: each lasts 1 h but starts ~0.53 h after the previous one
    with pytest.raises(ValueError):
        TimeMapAxis(tmin, tmax, reference_time=Time.now(), name="time")
486
487
488
def test_bad_length_sort_time_axis(time_intervals):
    """TimeMapAxis rejects reversed upper edges and edge arrays of unequal length."""
    tref = time_intervals["t_ref"]
    tmin = time_intervals["t_min"]
    tmax_reverse = time_intervals["t_max"][::-1]
    tmax_short = time_intervals["t_max"][:-1]

    with pytest.raises(ValueError):
        TimeMapAxis(tmin, tmax_reverse, tref, name="time")

    with pytest.raises(ValueError):
        TimeMapAxis(tmin, tmax_short, tref, name="time")
499
500
501
def test_coord_to_idx_time_axis(time_intervals):
    """Times outside every interval map to index -1 and pixel NaN."""
    tmin = time_intervals["t_min"]
    tmax = time_intervals["t_max"]
    tref = time_intervals["t_ref"]
    axis = TimeMapAxis(tmin, tmax, tref, name="time")

    time = Time(58927.020833333336, format="mjd")

    # Shift every even-indexed bin midpoint by 1 h, out of its 1 h interval,
    # and prepend two times years before the reference.
    times = axis.time_mid
    times[::2] += 1 * u.h
    times = times.insert(0, tref - [1, 2] * u.yr)

    idx = axis.coord_to_idx(time)
    indices = axis.coord_to_idx(times)

    pix = axis.coord_to_pix(time)
    pixels = axis.coord_to_pix(times)

    assert idx == 0
    assert_allclose(indices[1::2], [-1, 1, 3, 5, 7, 9, 11, 13, 15, 17, 19])
    assert_allclose(indices[::2], -1)
    assert_allclose(pix, 0, atol=1e-10)
    assert_allclose(pixels[1::2], [np.nan, 1, 3, 5, 7, 9, 11, 13, 15, 17, 19])
524
525
526
def test_slice_time_axis(time_intervals):
    """Index-list slicing keeps the chosen bins; squash keeps the overall end time."""
    axis = TimeMapAxis(
        time_intervals["t_min"], time_intervals["t_max"], time_intervals["t_ref"]
    )

    selection = [2, 6, 9]
    new_axis = axis.slice(selection)
    squashed = axis.squash()

    assert new_axis.nbin == 3
    # Check the selected bins themselves, not only how many there are.
    assert_allclose(new_axis.time_min.mjd, axis.time_min[selection].mjd)
    assert_allclose(new_axis.time_max.mjd, axis.time_max[selection].mjd)

    assert squashed.nbin == 1
    assert_allclose(squashed.time_max[0].mjd, 58937.041667)
538
539
540
def test_time_map_axis_from_time_bounds():
    """from_time_bounds splits 12 h into 3 equal bins; centres are relative days."""
    t_min = Time("2006-02-12", scale="utc")
    t_max = t_min + 12 * u.h

    axis = TimeMapAxis.from_time_bounds(time_min=t_min, time_max=t_max, nbin=3)
    assert_allclose(axis.center, [0.083333, 0.25, 0.416667] * u.d, rtol=1e-5)
546
547
548
def test_from_table_time_axis():
    """Read a TimeMapAxis from a GADF-style table with TIME_MIN/TIME_MAX columns."""
    t0 = Time("2006-02-12", scale="utc")
    t_min = np.linspace(0, 10, 10) * u.d
    t_max = t_min + 12 * u.h

    table = Table()
    table["TIME_MIN"] = t_min
    table["TIME_MAX"] = t_max
    table.meta.update(time_ref_to_dict(t0))
    table.meta["AXCOLS1"] = "TIME_MIN,TIME_MAX"

    axis = TimeMapAxis.from_table(table, format="gadf")

    assert axis.nbin == 10
    assert_allclose(axis.time_mid[0].mjd, 53778.25)
563
564
565
@requires_data()
def test_from_gti_time_axis():
    """Build a one-bin TimeMapAxis from the GTI of a H.E.S.S. DL3 event file."""
    filename = make_path(
        "$GAMMAPY_DATA/hess-dl3-dr1/data/hess_dl3_dr1_obs_id_020136.fits.gz"
    )
    gti = GTI.read(filename)

    axis = TimeMapAxis.from_gti(gti)
    expected = Time(53090.123451203704, format="mjd", scale="tt")
    assert_time_allclose(axis.time_min[0], expected)
    assert axis.nbin == 1
575
576
577
def test_map_with_time_axis(time_intervals):
    """A RegionNDMap with energy and time axes gets data shape (time, energy, 1, 1)."""
    time_axis = TimeMapAxis(
        time_intervals["t_min"], time_intervals["t_max"], time_intervals["t_ref"]
    )
    energy_axis = MapAxis.from_energy_bounds(0.1, 10, 2, unit="TeV")
    region_map = RegionNDMap.create(
        region="fk5; circle(0,0,0.1)", axes=[energy_axis, time_axis]
    )

    assert region_map.geom.data_shape == (20, 2, 1, 1)
587
588
589
def test_time_axis_plot_helpers():
    """Plot labels, centres and contiguous edges of a TimeMapAxis are absolute times."""
    time_ref = Time("1999-01-01T00:00:00.123456789")

    time_axis = TimeMapAxis(
        edges_min=[0, 1, 3] * u.d,
        edges_max=[0.8, 1.9, 5.4] * u.d,
        reference_time=time_ref,
    )

    labels = time_axis.as_plot_labels
    assert labels[0] == "1999-01-01 00:00:00.123 - 1999-01-01 19:12:00.123"

    center = time_axis.as_plot_center
    assert center[0].year == 1999

    edges = time_axis.to_contiguous().as_plot_edges
    assert edges[0].year == 1999
606
607
608
def test_axes_basics():
    """Shape, primary axis, copy equality and inequality of a MapAxes container."""
    energy_axis = MapAxis.from_energy_edges([1, 3] * u.TeV)

    time_ref = Time("1999-01-01T00:00:00.123456789")

    time_axis = TimeMapAxis(
        edges_min=[0, 1, 3] * u.d,
        edges_max=[0.8, 1.9, 5.4] * u.d,
        reference_time=time_ref,
    )

    axes = MapAxes([energy_axis, time_axis])

    assert axes.shape == (1, 3)
    assert axes.is_unidimensional
    assert not axes.is_flat

    assert axes.primary_axis.name == "time"

    # Each copied axis must equal its source. Comparing an axis with itself
    # would be a tautology.
    new_axes = axes.copy()
    assert new_axes[0] == axes[0]
    assert new_axes[1] == axes[1]
    assert new_axes == axes

    other_energy_axis = MapAxis.from_energy_edges([1, 4] * u.TeV)
    new_axes = MapAxes([other_energy_axis, time_axis.copy()])
    assert new_axes != axes
635
636
637
def test_axes_getitem():
    """MapAxes indexing: an int gives a MapAxis; slices and name lists give MapAxes."""
    axis1 = MapAxis.from_bounds(1, 4, 3, name="a1")
    axis2 = axis1.copy(name="a2")
    axis3 = axis1.copy(name="a3")
    axes = MapAxes([axis1, axis2, axis3])

    assert isinstance(axes[0], MapAxis)
    assert axes[-1].name == "a3"

    assert isinstance(axes[1:], MapAxes)
    assert len(axes[1:]) == 2
    assert isinstance(axes[0:1], MapAxes)
    assert len(axes[0:1]) == 1

    # Selecting by a list of names keeps the requested order.
    by_name = axes[["a3", "a1"]]
    assert isinstance(by_name, MapAxes)
    assert by_name[0].name == "a3"
651
652
653
def test_label_map_axis_basics():
    """String form, name checks, bin properties and copy of a LabelMapAxis."""
    axis = LabelMapAxis(labels=["label-1", "label-2"], name="label-axis")

    axis_str = str(axis)
    assert "node type" in axis_str
    assert "labels" in axis_str
    assert "label-2" in axis_str

    with pytest.raises(ValueError):
        axis.assert_name("time")

    assert axis.nbin == 2
    assert axis.node_type == "label"

    assert_allclose(axis.bin_width, 1)

    assert axis.name == "label-axis"

    # A label axis exposes no edges.
    with pytest.raises(ValueError):
        axis.edges

    axis_copy = axis.copy()
    assert axis_copy.name == "label-axis"
676
677
678
def test_label_map_axis_coord_to_idx():
    """Labels map to indices for scalar, flat and nested input; unknown labels raise."""
    axis = LabelMapAxis(labels=["label-1", "label-2", "label-3"], name="label-axis")

    idx = axis.coord_to_idx(coord="label-1")
    assert_allclose(idx, 0)

    idx = axis.coord_to_idx(coord=["label-1", "label-3"])
    assert_allclose(idx, [0, 2])

    # The nested input shape is preserved in the output.
    idx = axis.coord_to_idx(coord=[["label-1"], ["label-2"]])
    assert_allclose(idx, [[0], [1]])

    # Only the call under test goes inside pytest.raises.
    bad_labels = [["bad-label"], ["label-2"]]
    with pytest.raises(ValueError):
        axis.coord_to_idx(coord=bad_labels)
696
697
698
def test_mixed_axes():
    """Coordinates, indices and GADF table export for energy, time and label axes."""
    label_axis = LabelMapAxis(labels=["label-1", "label-2", "label-3"], name="label")

    time_axis = TimeMapAxis(
        edges_min=[1, 10] * u.day,
        edges_max=[2, 13] * u.day,
        reference_time=Time("2020-03-19"),
    )

    energy_axis = MapAxis.from_energy_bounds("1 TeV", "10 TeV", nbin=4)

    axes = MapAxes(axes=[energy_axis, time_axis, label_axis])

    # Each coordinate array broadcasts along its own axis position.
    coords = axes.get_coord()

    assert coords["label"].shape == (1, 1, 3)
    assert coords["energy"].shape == (4, 1, 1)
    assert coords["time"].shape == (1, 2, 1)

    idx = axes.coord_to_idx(coords)

    assert_allclose(idx[0], np.arange(4).reshape((4, 1, 1)))
    assert_allclose(idx[1], np.arange(2).reshape((1, 2, 1)))
    assert_allclose(idx[2], np.arange(3).reshape((1, 1, 3)))

    hdu = axes.to_table_hdu(format="gadf")
    table = Table.read(hdu)

    assert table["LABEL"].dtype == np.dtype("U7")
    assert len(table) == 24
729
730
731
def test_map_axis_format_plot_xaxis():
    """format_plot_xaxis labels a true-energy axis with its name and unit."""
    axis = MapAxis.from_energy_bounds(
        "0.03 TeV", "300 TeV", nbin=20, per_decade=True, name="energy_true"
    )

    with mpl_plot_check():
        ax = plt.gca()
        with quantity_support():
            ax.plot(axis.center, np.ones_like(axis.center))

    ax1 = axis.format_plot_xaxis(ax=ax)
    assert ax1.get_xlabel() == "True Energy [TeV]"
743
744
745
def test_time_map_axis_format_plot_xaxis(time_intervals):
    """format_plot_xaxis labels a time axis as "Time [iso]"."""
    axis = TimeMapAxis(
        time_intervals["t_min"],
        time_intervals["t_max"],
        time_intervals["t_ref"],
        name="time",
    )

    with mpl_plot_check():
        ax = plt.gca()
        with quantity_support():
            ax.plot(axis.center, np.ones_like(axis.center))

    ax1 = axis.format_plot_xaxis(ax=ax)
    assert ax1.get_xlabel() == "Time [iso]"
760