Passed
Pull Request — main (#204)
by Chaitanya
01:30
created

asgardpy.stats.stats.get_goodness_of_fit_stats()   A

Complexity

Conditions 1

Size

Total Lines 54
Code Lines 15

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 1
eloc 15
nop 2
dl 0
loc 54
rs 9.65
c 0
b 0
f 0

How to fix   Long Method   

Long Method

Small methods make your code easier to understand, in particular if combined with a good name. Besides, if your method is small, finding a good name is usually much easier.

For example, if you find yourself adding comments to a method's body, this is usually a good sign to extract the commented part to a new method, and use the comment as a starting point when coming up with a good name for this new method.

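Commonly applied refactorings include Extract Method: move a coherent part of the method body into a new method whose name describes what that part does.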

1
"""
2
Module for performing some statistic functions.
3
"""
4
5
import numpy as np
6
from gammapy.stats.fit_statistics import cash, wstat
7
from scipy.stats import chi2, norm
8
9
# Public API of this module, listed alphabetically.
__all__ = [
    "check_model_preference_aic", "check_model_preference_lrt",
    "fetch_pivot_energy", "get_chi2_sig_pval",
    "get_goodness_of_fit_stats", "get_ts_target",
]
17
18
19
def get_chi2_sig_pval(test_stat, ndof):
    """
    Convert a fit test statistic into a p-value and a Gaussian significance.

    The chi-square survival function with ``ndof`` degrees of freedom gives
    the probability of obtaining a test statistic at least as large as
    ``test_stat``. That tail probability is then expressed as a number of
    Gaussian sigmas using the two-sided convention ``norm.isf(pval / 2)``.

    In Gammapy, for 3D analysis, cash statistics is used, while for 1D analysis,
    wstat statistics is used. Check the documentation for more details
    https://docs.gammapy.org/1.3/user-guide/stats/index.html

    Parameters
    ----------
    test_stat: float
        The test statistic (-2 ln L) value of the fitting.
    ndof: int
        Total number of degrees of freedom.

    Returns
    -------
    chi2_sig: float
        Significance (Chi2) of the likelihood of the fit model, expressed in
        Gaussian sigmas.
    pval: float
        Upper-tail p-value of ``test_stat`` for a chi-square distribution
        with ``ndof`` degrees of freedom.
    """
    tail_probability = chi2.sf(test_stat, ndof)

    # Split the tail probability across both sides of the Gaussian.
    significance = norm.isf(tail_probability / 2)

    return significance, tail_probability
49
50
51
def check_model_preference_lrt(test_stat_1, test_stat_2, ndof_1, ndof_2):
    """
    Likelihood ratio test of a "nested" spectral model2 (observed) against a
    primary model1.

    The drop in test statistic, ``test_stat_1 - test_stat_2``, is evaluated as
    a chi-square variable with ``ndof_1 - ndof_2`` degrees of freedom through
    `get_chi2_sig_pval`. When that DoF difference is below 1, a message is
    printed and NaN is returned for both the p-value and the significance.

    Parameters
    ----------
    test_stat_1: float
        The test statistic (-2 ln L) of the Fit result of the primary spectral model.
    test_stat_2: float
        The test statistic (-2 ln L) of the Fit result of the nested spectral model.
    ndof_1: int
        Number of degrees of freedom for the primary model
    ndof_2: int
        Number of degrees of freedom for the nested model

    Returns
    -------
    p_value: float
        p-value for the ratio of the likelihoods (NaN if the DoF difference is below 1).
    gaussian_sigmas: float
        Significance (Chi2) of the ratio of the likelihoods, expressed in
        Gaussian sigmas (NaN if the DoF difference is below 1).
    n_dof: int
        Difference ``ndof_1 - ndof_2``, i.e. the number of degrees of freedom
        (free parameters) separating the primary and nested models.
    """
    dof_gain = ndof_1 - ndof_2

    if dof_gain < 1:
        # The chi-square comparison needs a DoF difference of at least 1.
        print(f"DoF is lower in {ndof_1} compared to {ndof_2}")
        return np.nan, np.nan, dof_gain

    stat_drop = test_stat_1 - test_stat_2
    sigmas, p_val = get_chi2_sig_pval(stat_drop, dof_gain)

    # Note the order: p-value first, unlike get_chi2_sig_pval.
    return p_val, sigmas, dof_gain
88
89
90
def check_model_preference_aic(list_stat, list_dof):
91
    """
92
    Akaike Information Criterion (AIC) preference over a list of stat and DoF
93
    (degree of freedom) to get relative likelihood of a given list of best-fit
94
    models.
95
96
    Parameters
97
    ----------
98
    list_wstat: list
99
        List of stat or -2 Log likelihood values for a list of models.
100
    list_dof: list
101
        List of degrees of freedom or list of free parameters, for a list of models.
102
103
    Returns
104
    -------
105
    list_rel_p: list
106
        List of relative likelihood probabilities, for a list of models.
107
    """
108
    list_aic_stat = []
109
    for stat, dof in zip(list_stat, list_dof, strict=True):
110
        aic_stat = stat + 2 * dof
111
        list_aic_stat.append(aic_stat)
112
    list_aic_stat = np.array(list_aic_stat)
113
114
    aic_stat_min = np.min(list_aic_stat)
115
116
    list_b_stat = []
117
    for aic in list_aic_stat:
118
        b_stat = np.exp((aic_stat_min - aic) / 2)
119
        list_b_stat.append(b_stat)
120
    list_b_stat = np.array(list_b_stat)
121
122
    list_rel_p = []
123
    for b_stat in list_b_stat:
124
        rel_p = b_stat / np.sum(list_b_stat)
125
        list_rel_p.append(rel_p)
126
    list_rel_p = np.array(list_rel_p)
127
128
    return list_rel_p
129
130
131
def get_goodness_of_fit_stats(datasets, instrument_spectral_info):
    """
    Evaluate the Goodness of Fit statistics of the model fitted to the datasets.

    `get_ts_target` provides the total test statistic for the (observed) best
    fit of the model to the data and for the (expected) perfect fit
    (model = data) in the target source region/pixel. Their difference is
    passed to `get_chi2_sig_pval`, together with the degrees of freedom stored
    under the ``"DoF"`` key of ``instrument_spectral_info``. The DoF is meant
    to be the number of relevant energy bins minus the number of free model
    parameters.

    ``instrument_spectral_info`` is modified in place. It gains the keys
    ``"max_fit_stat"``, ``"best_fit_stat"``, ``"fit_chi2_sig"`` and
    ``"fit_pval"``, and the same object is returned.

    Parameters
    ----------
    datasets: `gammapy.datasets.Datasets`
        List of Datasets object, which can contain 3D and/or 1D datasets
    instrument_spectral_info: dict
        Dict of information for storing relevant fit stats. Must contain "DoF".

    Returns
    -------
    instrument_spectral_info: dict
        Filled Dict of information with relevant fit statistics
    stat_message: str
        String for logging the fit statistics

    Raises
    ------
    KeyError
        If ``instrument_spectral_info`` has no ``"DoF"`` entry. This is raised
        after the two fit-stat keys have already been written.
    """
    best_stat, max_stat = get_ts_target(datasets)

    instrument_spectral_info["max_fit_stat"] = max_stat
    instrument_spectral_info["best_fit_stat"] = best_stat

    dof = instrument_spectral_info["DoF"]
    gof_stat = best_stat - max_stat
    chi2_sig, pval = get_chi2_sig_pval(gof_stat, dof)

    instrument_spectral_info["fit_chi2_sig"] = chi2_sig
    instrument_spectral_info["fit_pval"] = pval

    message = _format_gof_message(gof_stat, dof, pval, chi2_sig, best_stat, max_stat)
    return instrument_spectral_info, message


def _format_gof_message(gof_stat, dof, pval, chi2_sig, best_stat, max_stat):
    """Build the multi-line logging message summarising the goodness-of-fit statistics."""
    return "".join(
        (
            "The Chi2/dof value of the goodness of Fit is ",
            f"{gof_stat:.2f}/{dof}\nand the p-value is {pval:.3e} ",
            f"and in Significance {chi2_sig:.2f} sigmas",
            f"\nwith best fit TS (Observed) as {best_stat:.3f} ",
            f"and max fit TS (Expected) as {max_stat:.3f}",
        )
    )
185
186
187
def get_ts_target(datasets):
188
    """
189
    From a given list of DL4 datasets, with assumed associated models, estimate
190
    the total test statistic values, in the given target source region/pixel,
191
    for the (observed) best fit of the model to the data, and the (expected)
192
    perfect fit of model and data (model = data).
193
194
    For consistency in the evaluation of the statistic values, we will use the
195
    basic Fit Statistic functions in Gammapy for Poisson Data:
196
197
    * `cash <https://docs.gammapy.org/1.3/api/gammapy.stats.cash.html>`_
198
199
    * `wstat <https://docs.gammapy.org/1.3/api/gammapy.stats.wstat.html>`_
200
201
    For the different type of Statistics used in Gammapy for 3D/1D datasets,
202
    and for our use case of getting the best fit and perfect fit, we will pass
203
    the appropriate values, by adapting to the following methods,
204
205
    * Best Fit (Observed):
206
207
        * `Cash stat_array <https://docs.gammapy.org/1.3/api/gammapy.datasets.MapDataset.html#gammapy.datasets.MapDataset.stat_array # noqa>`_
208
209
        * `Wstat stat_array <https://docs.gammapy.org/1.3/api/gammapy.datasets.MapDatasetOnOff.html#gammapy.datasets.MapDatasetOnOff.stat_array # noqa>`_
210
211
    * Perfect Fit (Expected):
212
213
        * `Cash stat_max <https://docs.gammapy.org/1.3/api/gammapy.stats.CashCountsStatistic.html#gammapy.stats.CashCountsStatistic.stat_max # noqa>`_
214
215
        * `Wstat stat_max <https://docs.gammapy.org/1.3/api/gammapy.stats.WStatCountsStatistic.html#gammapy.stats.WStatCountsStatistic.stat_max # noqa>`_
216
217
    Parameter
218
    ---------
219
    datasets: `gammapy.datasets.Datasets`
220
        List of Datasets object, which can contain 3D and/or 1D datasets
221
222
    Return
223
    ------
224
    stat_best_fit: float
225
        Total sum of test statistic of the best fit of model to data, summed
226
        over all energy bins.
227
    stat_max_fit: float
228
        Test statistic difference of the perfect fit of model to data summed
229
        over all energy bins.
230
    """  # noqa
231
    stat_best_fit = 0
232
    stat_max_fit = 0
233
234
    for data in datasets:
235
        match data.stat_type:
236
            case "cash":
237
                # Assuming that the Counts Map is created with the target source as its center
238
                region = data.counts.geom.center_skydir
239
240
                counts_on = (data.counts.copy() * data.mask).get_spectrum(region).data
241
                mu_on = (data.npred() * data.mask).get_spectrum(region).data
242
243
                stat_best_fit += np.nansum(cash(n_on=counts_on, mu_on=mu_on).ravel())
244
                stat_max_fit += np.nansum(cash(n_on=counts_on, mu_on=counts_on).ravel())
245
246
            case "wstat":
247
                # Assuming that the Counts Map is created with the target source as its center
248
                region = data.counts.geom.center_skydir
249
250
                counts_on = (data.counts.copy() * data.mask).get_spectrum(region).data
251
                counts_off = np.nan_to_num((data.counts_off * data.mask).get_spectrum(region)).data
252
253
                # alpha is evaluated by acceptance ratios, and
254
                # Background is evaluated with given alpha and counts_off,
255
                # but for alpha to be of the same shape (in the target region),
256
                # it will be reevaluated
257
                bkg = np.nan_to_num((data.background * data.mask).get_spectrum(region))
258
259
                with np.errstate(invalid="ignore", divide="ignore"):
260
                    alpha = bkg / counts_off
261
                mu_signal = np.nan_to_num((data.npred_signal() * data.mask).get_spectrum(region)).data
262
                max_pred = counts_on - bkg
263
264
                stat_best_fit += np.nansum(wstat(n_on=counts_on, n_off=counts_off, alpha=alpha, mu_sig=mu_signal))
265
                stat_max_fit += np.nansum(wstat(n_on=counts_on, n_off=counts_off, alpha=alpha, mu_sig=max_pred))
266
267
            case "chi2":
268
                # For FluxxPointsDataset
269
                stat_best_fit += np.nansum(data.stat_array())
270
                stat_max_fit += len(data.data.dnde.data)
271
272
    return stat_best_fit, stat_max_fit
273
274
275
def fetch_pivot_energy(analysis):
    """
    Get the pivot energy of the fitted intrinsic model of an 'AsgardpyAnalysis'.

    If the analysis holds no DL4 datasets yet, every configured step except
    "flux-points" is run, i.e. the steps up to the Fit. Otherwise only the
    "fit" step is re-run. The intrinsic model is then obtained through
    ``get_correct_intrinsic_model`` and the ``pivot_energy`` of
    ``analysis.model_deabs.spectral_model`` is returned.

    Parameters
    ----------
    analysis: `AsgardpyAnalysis`
        Analysis object to run. It is modified by the steps executed here.

    Returns
    -------
    pivot energy : `~astropy.units.Quantity`
        The energy at which the statistical error in the computed flux is smallest.
        If no minimum is found, NaN will be returned.
    """
    if len(analysis.datasets) > 0:
        # DL4 datasets already exist, so only the fit has to be redone.
        analysis.run(["fit"])
    else:
        # No DL4 datasets yet: run the configured steps, leaving out flux points.
        analysis.run([step for step in analysis.config.general.steps if step != "flux-points"])

    analysis.get_correct_intrinsic_model()
    return analysis.model_deabs.spectral_model.pivot_energy
298