FluxPointsAnalysisStep._sort_datasets_info()   A
last analyzed

Complexity

Conditions 5

Size

Total Lines 28
Code Lines 14

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 5
eloc 14
nop 1
dl 0
loc 28
rs 9.2333
c 0
b 0
f 0
1
"""
2
Main classes to define High-level Analysis Config and the Analysis Steps.
3
"""
4
5
from enum import Enum
6
7
from astropy import units as u
8
from gammapy.datasets import Datasets, FluxPointsDataset
9
from gammapy.estimators import FluxPointsEstimator
10
from gammapy.modeling import Fit
11
from gammapy.utils.metadata import CreatorMetaData
12
13
from asgardpy.analysis.step_base import AnalysisStepBase
14
from asgardpy.base import BaseConfig, EnergyRangeConfig
15
from asgardpy.version import __public_version__
16
17
__all__ = ["FitAnalysisStep", "FitConfig", "FluxPointsAnalysisStep", "FluxPointsConfig"]
23
24
25
# Defining various components of High-level Analysis Config
class BackendEnum(str, Enum):
    """
    Config section for the list of supported Fitting backend methods.

    The selected value is passed as the ``backend`` argument of the Gammapy
    ``Fit`` object.
    """

    minuit = "minuit"
    scipy = "scipy"
31
32
33
class FitConfig(BaseConfig):
    """
    Config section holding the options for the Gammapy ``Fit`` object, and
    the energy range used to restrict the datasets before fitting.
    """

    # Energy range from which the fit mask of each (non flux-points) dataset
    # is built before running the fit.
    fit_range: EnergyRangeConfig = EnergyRangeConfig()
    # Fitting backend, passed as ``Fit(backend=...)``.
    backend: BackendEnum = BackendEnum.minuit
    # Option dicts forwarded unchanged to the corresponding ``Fit`` arguments.
    optimize_opts: dict = {}
    covariance_opts: dict = {}
    confidence_opts: dict = {}
    # Forwarded as ``Fit(store_trace=...)``.
    store_trace: bool = True
42
43
44
class FluxPointsConfig(BaseConfig):
    """
    Config section with the settings used to build the Gammapy
    FluxPointsEstimator.
    """

    # Extra keyword arguments, unpacked into the FluxPointsEstimator call.
    parameters: dict = {"selection_optional": "all"}
    # Value for the ``reoptimize`` argument of FluxPointsEstimator.
    reoptimize: bool = False
49
50
51
# The main Analysis Steps
class FitAnalysisStep(AnalysisStepBase):
    """
    Fit the models of the given datasets, using the Fitting parameters from
    the ``fit_params`` section of the Config.

    Before fitting, every dataset that is not a ``FluxPointsDataset`` gets
    its ``mask_fit`` set from the configured fit energy range.
    """

    tag = "fit"

    def _run(self):
        """Configure the fitter, prepare the datasets, run the fit and log the result."""
        self.fit_params = self.config.fit_params

        self._setup_fit()
        datasets_to_fit = self._set_datasets()
        self.fit_result = self.fit.run(datasets=datasets_to_fit)

        self.log.info(self.fit_result)

    def _setup_fit(self):
        """
        Create the Gammapy ``Fit`` object from ``self.fit_params`` and store
        it as ``self.fit``.
        """
        params = self.fit_params
        fit_kwargs = dict(
            backend=params.backend,
            optimize_opts=params.optimize_opts,
            covariance_opts=params.covariance_opts,
            confidence_opts=params.confidence_opts,
            store_trace=params.store_trace,
        )
        self.fit = Fit(**fit_kwargs)

    def _set_datasets(self):
        """
        Collect the datasets to fit into a new ``Datasets`` container.

        For each dataset that is not a ``FluxPointsDataset``, ``mask_fit`` is
        set to the energy mask of its counts geometry for the configured fit
        range. This modifies the dataset objects held in ``self.datasets``.

        Returns
        -------
        final_dataset: Datasets
            The same dataset objects, in their original order.
        """
        fit_range = self.fit_params.fit_range
        energy_min = u.Quantity(fit_range.min)
        energy_max = u.Quantity(fit_range.max)

        collected = Datasets()
        for dataset in self.datasets:
            if isinstance(dataset, FluxPointsDataset):
                # Flux points datasets are passed through without an energy mask.
                collected.append(dataset)
                continue

            # NOTE(review): this replaces any mask_fit already set on the
            # dataset instead of combining with it — confirm this is intended.
            dataset.mask_fit = dataset.counts.geom.energy_mask(energy_min, energy_max)
            collected.append(dataset)

        return collected
98
99
100
class FluxPointsAnalysisStep(AnalysisStepBase):
    """
    Using the Flux Points Estimator parameters in the config, and the given
    datasets and instrument_spectral_info, perform the Flux Points Estimation
    and store the result in a list of flux points, with one entry for each
    instrument that has at least one matching dataset.
    """

    tag = "flux-points"

    def _run(self):
        """
        Estimate flux points for each instrument group of datasets and store
        the resulting flux points objects in ``self.flux_points``.
        """
        self.flux_points = []
        datasets, energy_edges = self._sort_datasets_info()

        for dataset, energy_edge in zip(datasets, energy_edges, strict=True):
            # A new estimator is built per group, since the energy edges may
            # differ between instruments.
            self._set_fpe(energy_edge)
            flux_points = self.fpe.run(datasets=dataset)
            # ``dataset`` is a Datasets object, so this stores the list of the
            # names of all datasets in the group.
            flux_points.name = dataset.names

            flux_points.meta["creation"] = CreatorMetaData()
            flux_points.meta["creation"].creator += f", Asgardpy {__public_version__}"
            flux_points.meta["optional"] = {
                "instrument": flux_points.name,
            }

            self.flux_points.append(flux_points)

    def _set_fpe(self, energy_bin_edges):
        """
        Setup the Gammapy FluxPointsEstimator with the given energy bin edges
        and the remaining parameters from the config, and store it as
        ``self.fpe``.

        Parameters
        ----------
        energy_bin_edges:
            Energy edges of the flux points to estimate.

        Notes
        -----
        The entries of ``flux_points_params.parameters`` are unpacked as extra
        keyword arguments. They must not repeat any keyword that is set
        explicitly here (``energy_edges``, ``source``, ``n_jobs``,
        ``parallel_backend``, ``reoptimize``), otherwise Python raises a
        TypeError for the duplicated keyword.
        """
        fpe_settings = self.config.flux_points_params.parameters

        self.fpe = FluxPointsEstimator(
            energy_edges=energy_bin_edges,
            source=self.config.target.source_name,
            n_jobs=self.config.general.n_jobs,
            parallel_backend=self.config.general.parallel_backend,
            reoptimize=self.config.flux_points_params.reoptimize,
            **fpe_settings,
        )

    def _sort_datasets_info(self):
        """
        Group the datasets per instrument, together with that instrument's
        spectral energy edges, to be passed to the Flux Points Estimator.

        The given list of datasets may contain sub-instrument level datasets.
        A dataset is assigned to an instrument when the instrument name (from
        ``instrument_spectral_info["name"]``) is a substring of the dataset
        name. Instruments without any matching dataset are skipped. Because
        the match is by substring, one dataset can end up in several groups
        if instrument names overlap.

        Returns
        -------
        sorted_datasets: list of Datasets
            One Datasets object per instrument with at least one matching
            dataset, in the order of ``instrument_spectral_info["name"]``.
        sorted_energy_edges: list
            Energy edges for flux points estimation of the respective
            instruments' datasets, aligned with ``sorted_datasets``.
        """
        dataset_name_list = self.datasets.names
        sorted_datasets = []
        sorted_energy_edges = []

        for idx, instrument_name in enumerate(self.instrument_spectral_info["name"]):
            matched_datasets = [
                dataset
                for dataset, dataset_name in zip(self.datasets, dataset_name_list)
                if instrument_name in dataset_name
            ]
            if not matched_datasets:
                continue

            # The energy ranges are looked up by position, so they must be
            # ordered like instrument_spectral_info["name"].
            sorted_energy_edges.append(
                self.instrument_spectral_info["spectral_energy_ranges"][idx]
            )
            sorted_datasets.append(Datasets(matched_datasets))

        return sorted_datasets, sorted_energy_edges
170