Passed
Pull Request — main (#176)
by Chaitanya
01:21
created

check_preferred_model.main()   C

Complexity

Conditions 11

Size

Total Lines 80
Code Lines 53

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 11
eloc 53
nop 0
dl 0
loc 80
rs 5.3181
c 0
b 0
f 0

How to fix   Long Method    Complexity   

Long Method

Small methods make your code easier to understand, in particular if combined with a good name. Besides, if your method is small, finding a good name is usually much easier.

For example, if you find yourself adding comments to a method's body, this is usually a good sign to extract the commented part to a new method, and use the comment as a starting point when coming up with a good name for this new method.

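Commonly applied refactorings include Extract Method: move a cohesive part of the method body into a new method whose name describes what that part does.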

Complexity

Complex functions like check_preferred_model.main() often do a lot of different things. To break such a function down, we need to identify a cohesive component within it. A common approach to find such a component is to look for variables or steps that share the same prefixes or suffixes, or that are only ever used together.

Once you have determined the parts that belong together, you can move them into a separate, well-named helper function (Extract Method). For classes, the analogous refactoring is Extract Class; if the component makes sense as a sub-class, Extract Subclass is also a candidate, and is often faster.

1
import argparse
2
import logging
3
from pathlib import Path
4
5
import numpy as np
6
import yaml
7
from astropy.table import QTable
8
9
from asgardpy.analysis import AsgardpyAnalysis
10
from asgardpy.config import AsgardpyConfig
11
from asgardpy.config.generator import CONFIG_PATH
12
from asgardpy.stats.stats import check_model_preference_aic, check_model_preference_lrt
13
14
log = logging.getLogger(__name__)


def str_to_bool(value):
    """Convert a command-line string such as ``"False"`` or ``"yes"`` into a bool.

    ``argparse`` with ``type=bool`` calls ``bool(value)``, which is ``True`` for
    every non-empty string (including ``"False"``), so boolean options that take
    a value need an explicit converter.

    Parameters
    ----------
    value : str or bool
        The raw option value.

    Returns
    -------
    bool

    Raises
    ------
    argparse.ArgumentTypeError
        If ``value`` is not a recognised boolean spelling.
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in {"true", "t", "yes", "y", "1"}:
        return True
    if normalized in {"false", "f", "no", "n", "0"}:
        return False
    raise argparse.ArgumentTypeError(f"Expected a boolean value, got {value!r}")


parser = argparse.ArgumentParser(description="Get preferred best-fit spectral model")

parser.add_argument(
    "--config",
    "-c",
    help="Path to the config file",
    # main() builds a Path from this value, so it cannot be omitted.
    required=True,
)

# TODO: let the user choose which spectral models to test; currently a fixed
# set of model templates is always compared.
parser.add_argument("--ebl-scale-factor", "-e", help="Value of EBL Norm Scale Factor", default=1.0, type=float)

parser.add_argument(
    "--ebl-model-name",
    "-m",
    help="Name of EBL model as used by Gammapy",
    default="dominguez",
    type=str,
)

parser.add_argument(
    "--write-config",
    help="Boolean to write the best-fit model into a separate file.",
    default=True,
    # Not ``type=bool``: that would turn "--write-config False" into True.
    type=str_to_bool,
)
41
42
43
def get_model_config_files(select_model_tags):
    """Pick the bundled spectral-model template files whose tag is requested.

    A template's tag is the last ``_``-separated token of its file name before
    the first ``.`` (e.g. a file named ``model_template_pl.yaml`` has tag
    ``pl``).

    Parameters
    ----------
    select_model_tags : container of str
        Tags of the spectral models to keep.

    Returns
    -------
    numpy.ndarray
        The matching template paths, in sorted order.
    """
    candidates = sorted(CONFIG_PATH.glob("model_templates/model_template*yaml"))
    chosen = [path for path in candidates if path.name.split(".")[0].split("_")[-1] in select_model_tags]
    return np.array(chosen)
59
60
61
def update_config(config_1, config_2):
    """Copy shared target settings from ``config_1`` into ``config_2``.

    Copied values: the ``value`` of the first two spectral parameters of the
    first target component, its EBL-absorption redshift, the target source
    name, and the first component's name. The EBL reference model is NOT
    copied here.

    Parameters
    ----------
    config_1 : object with a ``.config.target`` attribute
        Source of the values.
    config_2 : object with a ``.config.target`` attribute
        Destination; modified in place.

    Returns
    -------
    The same ``config_2`` object.
    """
    source_target = config_1.config.target
    dest_target = config_2.config.target
    source_component = source_target.components[0]
    dest_component = dest_target.components[0]

    # Parameter 0 is the amplitude, parameter 1 the reference / e_break energy.
    for param_idx in (0, 1):
        dest_component.spectral.parameters[param_idx].value = source_component.spectral.parameters[param_idx].value

    # Only the redshift is synchronised; the EBL reference is left untouched.
    dest_component.spectral.ebl_abs.redshift = source_component.spectral.ebl_abs.redshift

    # Keep the source names consistent between both configs.
    dest_target.source_name = source_target.source_name
    dest_component.name = source_component.name

    return config_2
82
83
84
def fetch_all_analysis_objects(main_config, spec_model_temp_files, ebl_scale_factor, ebl_model_name):
    """Create one AsgardpyAnalysis per spectral-model template file.

    For each template, a base analysis is built from ``main_config`` with its
    ``models_file`` pointed at the template. A second analysis is then built
    from that config and synchronised with the base one via ``update_config``.
    The EBL settings of the second analysis are then adjusted:

    * ``alpha_norm`` is set to ``ebl_scale_factor`` unless it equals 1.0;
    * ``reference`` is set to ``ebl_model_name`` (underscores replaced by
      hyphens) unless it is ``"dominguez"``, in which case the base analysis'
      reference is copied.

    Returns
    -------
    main_analysis_list : dict
        Maps each model tag to ``{"Analysis": <AsgardpyAnalysis>}``.
    spec_models_list : numpy.ndarray
        The model tags, in the order of ``spec_model_temp_files``.
    """
    analyses_by_tag = {}
    tags = []

    for template_path in spec_model_temp_files:
        base_analysis = AsgardpyAnalysis(main_config)
        base_analysis.config.target.models_file = template_path

        analysis = AsgardpyAnalysis(base_analysis.config)
        update_config(base_analysis, analysis)

        ebl_abs = analysis.config.target.components[0].spectral.ebl_abs
        if ebl_scale_factor != 1.0:
            ebl_abs.alpha_norm = ebl_scale_factor

        # "dominguez" is the CLI default, so keep whatever the base config had.
        ebl_abs.reference = (
            ebl_model_name.replace("_", "-")
            if ebl_model_name != "dominguez"
            else base_analysis.config.target.components[0].spectral.ebl_abs.reference
        )

        tag = template_path.name.split(".")[0].split("_")[-1]
        tags.append(tag)
        analyses_by_tag[tag] = {"Analysis": analysis}

    return analyses_by_tag, np.array(tags)
116
117
118
def fetch_all_analysis_fit_info(main_analysis_list, spec_models_list):
    """
    For a list of spectral models, with the AsgardpyAnalysis run till the fit
    step, get the relevant information for testing the model preference.

    Each model is compared against the power-law ("pl") fit with
    ``check_model_preference_lrt``. For every tag, the results are stored in
    ``main_analysis_list[tag]`` under ``"Pref_over_pl_chi2"``,
    ``"Pref_over_pl_pval"`` and ``"DoF_over_pl"``; all three are 0 for "pl".

    Parameters
    ----------
    main_analysis_list : dict
        Maps each tag to a dict holding the fitted analysis under "Analysis".
        Must contain a "pl" entry. Modified in place (see above).
    spec_models_list : numpy.ndarray of str
        Tags to process, in order.

    Returns
    -------
    fit_success_list : numpy.ndarray of bool
        Fit success flag per tag, aligned with ``spec_models_list``.
    stat_list, dof_list, pref_over_pl_chi2_list : numpy.ndarray
        Best-fit statistic, DoF and LRT preference over PL, restricted to
        the successful fits only.

    Raises
    ------
    KeyError
        If ``main_analysis_list`` has no "pl" entry.
    """
    if "pl" not in main_analysis_list:
        raise KeyError("A power-law ('pl') analysis is required as the reference model")
    # The PL reference is the same for every tag, so fetch it once.
    dict_pl = main_analysis_list["pl"]["Analysis"].instrument_spectral_info

    fit_success_list = []
    pref_over_pl_chi2_list = []
    stat_list = []
    dof_list = []

    for tag in spec_models_list:
        entry = main_analysis_list[tag]
        analysis = entry["Analysis"]
        dict_tag = analysis.instrument_spectral_info

        # Collect parameters for AIC check
        stat = dict_tag["best_fit_stat"]
        dof = dict_tag["DoF"]

        fit_success_list.append(analysis.fit_result.success)
        stat_list.append(stat)
        dof_list.append(dof)

        # Checking the preference of a "nested" spectral model (observed),
        # over Power Law. PL compared to itself is defined as 0.
        if tag == "pl":
            entry["Pref_over_pl_chi2"] = 0
            entry["Pref_over_pl_pval"] = 0
            entry["DoF_over_pl"] = 0
            pref_over_pl_chi2_list.append(0)
            continue

        # NOTE(review): the second return value is stored as "chi2" but is
        # compared against 5 ("5 sigma") in main(); confirm its meaning.
        p_pl_x, g_pl_x, ndof_pl_x = check_model_preference_lrt(
            dict_pl["best_fit_stat"],
            stat,
            dict_pl["DoF"],
            dof,
        )

        entry["Pref_over_pl_chi2"] = g_pl_x
        entry["Pref_over_pl_pval"] = p_pl_x
        entry["DoF_over_pl"] = ndof_pl_x
        pref_over_pl_chi2_list.append(g_pl_x)

    # dtype=bool keeps this a valid boolean mask even for an empty input
    # (np.array([]) would be float64 and could not be used as an index).
    fit_success_list = np.array(fit_success_list, dtype=bool)

    # Only select fit results that were successful for comparisons
    stat_list = np.array(stat_list)[fit_success_list]
    dof_list = np.array(dof_list)[fit_success_list]
    pref_over_pl_chi2_list = np.array(pref_over_pl_chi2_list)[fit_success_list]

    return fit_success_list, stat_list, dof_list, pref_over_pl_chi2_list
171
172
173
def tabulate_best_fit_stats(spec_models_list, fit_success_list, main_analysis_list, list_rel_p):
    """Build a QTable summarising the fit of every successfully fitted model.

    Parameters
    ----------
    spec_models_list : numpy.ndarray of str
        All model tags.
    fit_success_list : numpy.ndarray of bool
        Mask selecting the successfully fitted tags.
    main_analysis_list : dict
        Maps each tag to a dict with the fitted analysis under "Analysis" and
        the "Pref_over_pl_*" / "DoF_over_pl" entries.
    list_rel_p : sequence of float
        Relative AIC p-values, aligned with the masked tags.

    Returns
    -------
    astropy.table.QTable
        One row per successful model. Statistics are rounded to 3 decimals and
        p-values to 4 significant digits.
    """

    def _four_significant(value):
        # Round to 4 significant digits and return a plain float.
        return float(f"{value:.4g}")

    rows = []
    for row_idx, tag in enumerate(spec_models_list[fit_success_list]):
        entry = main_analysis_list[tag]
        info = entry["Analysis"].instrument_spectral_info
        best_stat = info["best_fit_stat"]
        max_stat = info["max_fit_stat"]

        rows.append(
            {
                "Spectral Model": tag.upper(),
                "TS of Best Fit": round(best_stat, 3),
                "TS of Max Fit": round(max_stat, 3),
                "TS of Goodness of Fit": round(best_stat - max_stat, 3),
                "DoF of Fit": info["DoF"],
                r"Significance ($\sigma$) of Goodness of Fit": round(info["fit_chi2_sig"], 3),
                "p-value of Goodness of Fit": _four_significant(info["fit_pval"]),
                "Pref over PL (chi2)": round(entry["Pref_over_pl_chi2"], 3),
                "Pref over PL (p-value)": _four_significant(entry["Pref_over_pl_pval"]),
                "Pref over PL (DoF)": entry["DoF_over_pl"],
                "Relative p-value (AIC)": _four_significant(list_rel_p[row_idx]),
            }
        )

    return QTable(rows)
201
202
203
def write_output_config_yaml(model_):
    """Render the spectral part of ``model_`` as a minimal Asgardpy config YAML string.

    ``model_.spectral_model.model1`` is serialised with ``to_dict()``. The result
    replaces the first target component of a default ``AsgardpyConfig``. Only
    non-default values are kept. ``target.models_file`` and the component's
    ``spectral.ebl_abs`` are removed.
    NOTE(review): ``model1`` is presumably the intrinsic part of a compound
    (e.g. EBL-absorbed) spectral model; confirm.

    Parameters
    ----------
    model_ : object with a ``spectral_model.model1`` attribute
        The fitted model to export.

    Returns
    -------
    str
        The YAML text. Despite the function name, nothing is written to disk
        here; the caller writes the string.
    """
    spectral_dict = model_.spectral_model.model1.to_dict()

    base_config = AsgardpyConfig()
    # Put the spectral model information in place of the first component.
    base_config.target.components[0] = spectral_dict
    config_dict = base_config.dict(exclude_defaults=True)

    # Strip keys that should not appear in the exported file.
    target_dict = config_dict["target"]
    target_dict.pop("models_file", None)
    target_dict["components"][0]["spectral"].pop("ebl_abs", None)

    return yaml.dump(config_dict, sort_keys=False, indent=4, width=80, default_flow_style=None)
225
226
227
# A model is preferred over PL by the LRT if its preference score exceeds this.
# NOTE(review): this is described as "5 sigma", which assumes the value returned
# by check_model_preference_lrt is a Gaussian significance; confirm.
LRT_PREFERENCE_THRESHOLD = 5
# A model is preferred by the AIC if its relative p-value exceeds this.
AIC_PREFERENCE_THRESHOLD = 0.95
# Tags of the spectral-model templates that are compared against each other.
SPEC_MODEL_TAGS = ["lp", "bpl", "ecpl", "pl", "eclp", "sbpl"]


def select_preferred_index(scores, threshold):
    """Return the index of the highest score if it exceeds ``threshold``, else None.

    NaN scores are ignored. If several entries share the maximum, the last of
    them is returned (the same tie-breaking as the former loop-based selection).
    Returns None for an empty or all-NaN ``scores``.
    """
    scores = np.asarray(scores, dtype=float)
    if np.isnan(scores).all():
        return None
    best_idx = int(np.flatnonzero(scores == np.nanmax(scores))[-1])
    if scores[best_idx] > threshold:
        return best_idx
    return None


def write_fit_stats_table(stats_table, outdir, file_stem, target_source_name, ebl_model_name, ebl_scale_factor):
    """Attach run metadata to ``stats_table`` and write it as an ECSV file into ``outdir``."""
    stats_table.meta["Target source name"] = target_source_name
    stats_table.meta["EBL model"] = ebl_model_name
    stats_table.meta["EBL scale factor"] = ebl_scale_factor

    file_name = f"{file_stem}_{ebl_model_name}_{ebl_scale_factor}_fit_stats.ecsv"
    stats_table.write(
        outdir / file_name,
        format="ascii.ecsv",
        overwrite=True,
    )


def main():
    """Fit every spectral-model template and report the preferred model.

    The steps are:

    1. Read the config given by ``--config``.
    2. Run its analysis steps, except "flux-points", once per model template.
    3. Write a fit-statistics ECSV table into the config's output directory.
    4. Select the preferred model by LRT (against PL) and by AIC. PL is the
       fallback for both.
    5. If ``--write-config`` is set, write each selection as a YAML file next
       to the input config.

    Raises
    ------
    RuntimeError
        If none of the spectral model fits succeeded.
    """
    args = parser.parse_args()

    main_config = AsgardpyConfig.read(args.config)
    config_path = Path(args.config)
    config_path_file_name = config_path.name.split(".")[0]
    target_source_name = main_config.target.source_name

    # Flux points are not needed to compare the models.
    steps_list = [step.value for step in main_config.general.steps if step.value != "flux-points"]
    log.info("Target source is: %s", target_source_name)

    spec_model_temp_files = get_model_config_files(SPEC_MODEL_TAGS)

    main_analysis_list, spec_models_list = fetch_all_analysis_objects(
        main_config, spec_model_temp_files, args.ebl_scale_factor, args.ebl_model_name
    )

    # Run Analysis Steps till Fit
    for tag in spec_models_list:
        log.info("Spectral model being tested: %s", tag)
        main_analysis_list[tag]["Analysis"].run(steps_list)

    fit_success_list, stat_list, dof_list, pref_over_pl_chi2_list = fetch_all_analysis_fit_info(
        main_analysis_list, spec_models_list
    )
    fit_success_list = np.asarray(fit_success_list, dtype=bool)
    if not fit_success_list.any():
        raise RuntimeError("None of the spectral model fits were successful")

    # The statistics arrays contain successful fits only, so every index
    # selected below refers to this filtered list of tags.
    successful_tags = spec_models_list[fit_success_list]

    list_rel_p = check_model_preference_aic(stat_list, dof_list)

    stats_table = tabulate_best_fit_stats(spec_models_list, fit_success_list, main_analysis_list, list_rel_p)
    write_fit_stats_table(
        stats_table,
        main_config.general.outdir,
        config_path_file_name,
        target_source_name,
        args.ebl_model_name,
        args.ebl_scale_factor,
    )

    # PL is the fallback for both criteria. Its position must be looked up in
    # the filtered list; an index into the unfiltered list is wrong as soon as
    # an earlier model's fit failed.
    pl_positions = np.flatnonzero(successful_tags == "pl")
    if pl_positions.size == 0:
        log.error("The PL fit was not successful; cannot select a preferred model relative to PL")
        return
    pl_idx = int(pl_positions[0])

    # If any spectral model has at least 5 sigmas preference over PL
    sp_idx_lrt = select_preferred_index(pref_over_pl_chi2_list, LRT_PREFERENCE_THRESHOLD)
    if sp_idx_lrt is None:
        sp_idx_lrt = pl_idx
        log.info("No other model preferred over PL")
    else:
        log.info("Best preferred spectral model over PL is %s", successful_tags[sp_idx_lrt])

    sp_idx_aic = select_preferred_index(list_rel_p, AIC_PREFERENCE_THRESHOLD)
    if sp_idx_aic is None:
        sp_idx_aic = pl_idx
        log.info("No other model preferred, hence PL is selected")
    else:
        log.info("Best preferred spectral model is %s", successful_tags[sp_idx_aic])

    if args.write_config:
        log.info("Write the spectral model")

        for name, idx in (("lrt", sp_idx_lrt), ("aic", sp_idx_aic)):
            tag = successful_tags[idx]
            path = config_path.parent / f"{config_path_file_name}_model_most_pref_{name}.yaml"
            yaml_ = write_output_config_yaml(main_analysis_list[tag]["Analysis"].final_model[0])
            path.write_text(yaml_)


if __name__ == "__main__":
    main()
311