Passed
Pull Request — master (#2111)
by
unknown
02:29
created

TestSimpleFit.test_wstat()   A

Complexity

Conditions 1

Size

Total Lines 16
Code Lines 12

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 1
eloc 12
nop 1
dl 0
loc 16
rs 9.8
c 0
b 0
f 0
1
# Licensed under a 3-clause BSD style license - see LICENSE.rst
2
import pytest
3
from numpy.testing import assert_allclose
4
import astropy.units as u
5
import numpy as np
6
from ...utils.testing import requires_data, requires_dependency
7
from ...utils.random import get_random_state
8
from ...irf import EffectiveAreaTable, EnergyDispersion
9
from ...utils.fitting import Fit
10
from ..models import PowerLaw, ConstantModel, ExponentialCutoffPowerLaw
11
from ...spectrum import PHACountsSpectrum, SpectrumDatasetOnOff
12
13
14
class TestSpectrumDatasetOnOff:
    """Test ON OFF SpectrumDataset"""

    def setup_method(self):
        """Build the shared fixtures: IRFs, ON/OFF counts and livetime.

        ``setup_method`` is the pytest xunit-style per-test hook; the
        nose-style ``setup`` name is deprecated (pytest 7.2) and no longer
        called by pytest 8.
        """
        etrue = np.logspace(-1, 1, 10) * u.TeV
        self.e_true = etrue
        ereco = np.logspace(-1, 1, 5) * u.TeV
        elo = ereco[:-1]
        ehi = ereco[1:]

        # 10 true-energy edges -> 9 bins, hence 9 effective-area values.
        self.aeff = EffectiveAreaTable(etrue[:-1], etrue[1:], np.ones(9) * u.cm ** 2)
        self.edisp = EnergyDispersion.from_diagonal_response(etrue, ereco)

        # The OFF counts and backscal are 10x the ON values; test_alpha
        # checks the resulting alpha of 0.1.
        self.on_counts = PHACountsSpectrum(
            elo, ehi, np.ones_like(elo), backscal=np.ones_like(elo)
        )
        self.off_counts = PHACountsSpectrum(
            elo, ehi, np.ones_like(elo) * 10, backscal=np.ones_like(elo) * 10
        )

        self.livetime = 1000 * u.s

    def _make_dataset(self, **kwargs):
        """Return a SpectrumDatasetOnOff built from the shared fixtures.

        Keyword arguments are forwarded to ``SpectrumDatasetOnOff`` and
        override or extend the defaults (ON/OFF counts, aeff, edisp, livetime).
        """
        options = dict(
            counts_on=self.on_counts,
            counts_off=self.off_counts,
            aeff=self.aeff,
            edisp=self.edisp,
            livetime=self.livetime,
        )
        options.update(kwargs)
        return SpectrumDatasetOnOff(**options)

    def test_init_no_model(self):
        """Without a model, predicted counts and parameters are unavailable."""
        dataset = self._make_dataset()

        with pytest.raises(AttributeError):
            dataset.npred()

        with pytest.raises(AttributeError):
            print(dataset.parameters)

    def test_alpha(self):
        """Alpha has one entry per reco-energy bin and equals ON/OFF backscal."""
        dataset = self._make_dataset()

        assert dataset.alpha.shape == (4,)
        assert_allclose(dataset.alpha, 0.1)

    def test_data_shape(self):
        """The dataset shape follows the ON counts data."""
        dataset = self._make_dataset()

        assert dataset.data_shape == self.on_counts.data.data.shape

    def test_npred_no_edisp(self):
        """Without edisp, npred is aeff * flux * energy range * livetime."""
        const = 1 / u.TeV / u.cm ** 2 / u.s
        model = ConstantModel(const)
        livetime = 1 * u.s
        # Built explicitly (no edisp argument at all) rather than through
        # _make_dataset, which would always pass an edisp.
        dataset = SpectrumDatasetOnOff(
            counts_on=self.on_counts,
            counts_off=self.off_counts,
            aeff=self.aeff,
            model=model,
            livetime=livetime,
        )

        # aeff is constant (all ones), so the total is a single product.
        expected = (
            self.aeff.data.data[0]
            * (self.aeff.energy.hi[-1] - self.aeff.energy.lo[0])
            * const
            * livetime
        )

        assert_allclose(dataset.npred().sum(), expected.value)

    def test_incorrect_mask(self):
        """An integer-typed mask is rejected with ValueError."""
        # NOTE(review): presumably a boolean mask is required — confirm
        # against SpectrumDatasetOnOff's validation.
        mask = np.ones(self.on_counts.data.data.shape, dtype="int")

        with pytest.raises(ValueError):
            self._make_dataset(mask=mask)
108
109
110
@requires_dependency("iminuit")
class TestSimpleFit:
    """Test fit on counts spectra without any IRFs"""

    def setup_method(self):
        """Simulate source, background and OFF counts from two power laws.

        A fixed random seed keeps the fitted values asserted below
        reproducible. The draw order (source, background, OFF) matters
        because all draws share one random state.
        """
        self.nbins = 30
        binning = np.logspace(-1, 1, self.nbins + 1) * u.TeV
        energy_lo = binning[:-1]
        energy_hi = binning[1:]

        self.source_model = PowerLaw(
            index=2, amplitude=1e5 / u.TeV, reference=0.1 * u.TeV
        )
        self.bkg_model = PowerLaw(index=3, amplitude=1e4 / u.TeV, reference=0.1 * u.TeV)

        # ON/OFF normalisation: the OFF spectrum gets backscal = 1 / alpha and
        # 1 / alpha times the expected background counts.
        self.alpha = 0.1
        random_state = get_random_state(23)

        npred = self.source_model.integral(energy_lo, energy_hi)
        source_counts = random_state.poisson(npred)
        self.src = PHACountsSpectrum(
            energy_lo=energy_lo,
            energy_hi=energy_hi,
            data=source_counts,
            backscal=1,
        )
        # Currently it's necessary to specify a livetime
        self.src.livetime = 1 * u.s

        npred_bkg = self.bkg_model.integral(energy_lo, energy_hi)
        bkg_counts = random_state.poisson(npred_bkg)
        off_counts = random_state.poisson(npred_bkg / self.alpha)
        self.bkg = PHACountsSpectrum(
            energy_lo=energy_lo, energy_hi=energy_hi, data=bkg_counts
        )
        self.off = PHACountsSpectrum(
            energy_lo=energy_lo,
            energy_hi=energy_hi,
            data=off_counts,
            backscal=1.0 / self.alpha,
        )

    def test_wstat(self):
        """WStat with on source and background spectrum"""
        on_vector = self.src.copy()
        on_vector.data.data += self.bkg.data.data
        obs = SpectrumDatasetOnOff(counts_on=on_vector, counts_off=self.off)
        obs.model = self.source_model

        # Start the fit away from the simulated index (2). The parameter must
        # be addressed by name, as in the assertions below; assigning a
        # ``parameters.index`` attribute does not appear to touch the model
        # parameter at all.
        self.source_model.parameters["index"].value = 1.12

        fit = Fit(obs)
        result = fit.run()
        pars = self.source_model.parameters

        assert_allclose(pars["index"].value, 1.997342, rtol=1e-3)
        assert_allclose(pars["amplitude"].value, 100245.187067, rtol=1e-3)
        assert_allclose(result.total_stat, 30.022316, rtol=1e-3)

    def test_joint(self):
        """Test joint fit for obs with different energy binning"""
        on_vector = self.src.copy()
        on_vector.data.data += self.bkg.data.data
        obs1 = SpectrumDatasetOnOff(counts_on=on_vector, counts_off=self.off)
        obs1.model = self.source_model

        # Second observation: the same spectra rebinned by a factor of 2.
        src_rebinned = self.src.rebin(2)
        off_rebinned = self.off.rebin(2)
        src_rebinned.data.data += self.bkg.rebin(2).data.data

        obs2 = SpectrumDatasetOnOff(counts_on=src_rebinned, counts_off=off_rebinned)
        obs2.model = self.source_model

        fit = Fit([obs1, obs2])
        fit.run()
        pars = self.source_model.parameters
        assert_allclose(pars["index"].value, 1.996456, rtol=1e-3)
184
185
186
@requires_data("gammapy-data")
@requires_dependency("iminuit")
class TestSpectralFit:
    """Test fit in astrophysical scenario"""

    def setup_method(self):
        """Read two H.E.S.S. joint-crab observations and build the models.

        Also prepares ``self.fit``, a power-law fit of the first observation.
        """
        path = "$GAMMAPY_DATA/joint-crab/spectra/hess/"
        obs1 = SpectrumDatasetOnOff.read(path + "pha_obs23523.fits")
        obs2 = SpectrumDatasetOnOff.read(path + "pha_obs23592.fits")
        self.obs_list = [obs1, obs2]

        self.pwl = PowerLaw(
            index=2, amplitude=1e-12 * u.Unit("cm-2 s-1 TeV-1"), reference=1 * u.TeV
        )

        self.ecpl = ExponentialCutoffPowerLaw(
            index=2,
            amplitude=1e-12 * u.Unit("cm-2 s-1 TeV-1"),
            reference=1 * u.TeV,
            lambda_=0.1 / u.TeV,
        )

        # Example fit for one observation
        self.obs_list[0].model = self.pwl
        self.fit = Fit(self.obs_list[0])

    def set_model(self, model):
        """Attach ``model`` to every observation in ``self.obs_list``."""
        for obs in self.obs_list:
            obs.model = model

    def test_basic_results(self):
        """Power-law best-fit values, statistic and npred for one observation."""
        self.set_model(self.pwl)
        result = self.fit.run()
        pars = self.fit.datasets.parameters

        # The dataset must keep a reference to the very same model object.
        assert self.pwl is self.obs_list[0].model

        assert_allclose(result.total_stat, 38.343, rtol=1e-3)
        assert_allclose(pars["index"].value, 2.817, rtol=1e-3)
        assert pars["amplitude"].unit == "cm-2 s-1 TeV-1"
        assert_allclose(pars["amplitude"].value, 5.142e-11, rtol=1e-3)
        assert_allclose(self.obs_list[0].npred()[60], 0.6102, rtol=1e-3)
        # Smoke test: table export must not raise.
        pars.to_table()

    def test_basic_errors(self):
        """Parameter errors of the single-observation power-law fit."""
        self.set_model(self.pwl)
        self.fit.run()
        pars = self.fit.datasets.parameters

        assert_allclose(pars.error("index"), 0.1496, rtol=1e-3)
        assert_allclose(pars.error("amplitude"), 6.423e-12, rtol=1e-3)
        pars.to_table()

    def test_compound(self):
        """Fit a compound model (power law times a constant)."""
        model = self.pwl * 2
        self.set_model(model)
        fit = Fit(self.obs_list[0])
        fit.run()
        pars = fit.datasets.parameters

        assert_allclose(pars["index"].value, 2.8166, rtol=1e-3)
        p = pars["amplitude"]
        assert p.unit == "cm-2 s-1 TeV-1"
        assert_allclose(p.value, 5.0714e-12, rtol=1e-3)

    def test_ecpl_fit(self):
        """Fit an exponential-cutoff power law and check lambda_."""
        self.set_model(self.ecpl)
        fit = Fit(self.obs_list[0])
        fit.run()

        actual = fit.datasets.parameters["lambda_"].quantity
        assert actual.unit == "TeV-1"
        assert_allclose(actual.value, 0.145215, rtol=1e-2)

    def test_joint_fit(self):
        """Joint power-law fit of both observations."""
        self.set_model(self.pwl)
        fit = Fit(self.obs_list)
        fit.run()
        actual = fit.datasets.parameters["index"].value
        assert_allclose(actual, 2.7806, rtol=1e-3)

        actual = fit.datasets.parameters["amplitude"].quantity
        assert actual.unit == "cm-2 s-1 TeV-1"
        assert_allclose(actual.value, 5.200e-11, rtol=1e-3)
271