Completed
Push — main ( 2edd10...a5707d )
by Chaitanya
28s queued 16s
created

Dataset1DGeneration.update_dataset()   A

Complexity

Conditions 3

Size

Total Lines 17
Code Lines 8

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 3
eloc 8
nop 2
dl 0
loc 17
rs 10
c 0
b 0
f 0
1
"""
2
Main classes to define the 1D Dataset Config and the 1D Dataset Analysis Step,
and to generate 1D Datasets from the given instruments' DL3 data as specified
in the config.
4
"""
5
6
import logging
7
8
import numpy as np
9
from astropy import units as u
10
from gammapy.datasets import Datasets
11
12
from asgardpy.analysis.step_base import AnalysisStepBase
13
from asgardpy.base.base import BaseConfig
14
from asgardpy.base.geom import (
15
    GeomConfig,
16
    SkyPositionConfig,
17
    generate_geom,
18
    get_source_position,
19
)
20
from asgardpy.base.reduction import (
21
    BackgroundConfig,
22
    MapSelectionEnum,
23
    ObservationsConfig,
24
    ReductionTypeEnum,
25
    SafeMaskConfig,
26
    generate_dl4_dataset,
27
    get_bkg_maker,
28
    get_dataset_maker,
29
    get_dataset_reference,
30
    get_exclusion_region_mask,
31
    get_filtered_observations,
32
    get_safe_mask_maker,
33
)
34
from asgardpy.io.input_dl3 import InputDL3Config
35
from asgardpy.io.io_dl4 import DL4BaseConfig, DL4Files, get_reco_energy_bins
36
from asgardpy.version import __public_version__
37
38
# Names exported by ``from <this module> import *``.
__all__ = [
    "Datasets1DAnalysisStep",
    "Dataset1DBaseConfig",
    "Dataset1DConfig",
    "Dataset1DGeneration",
    "Dataset1DInfoConfig",
]

# Module-level logger, named after this module.
log = logging.getLogger(__name__)
47
48
49
# Defining various components of 1D Dataset Config section
50
class Dataset1DInfoConfig(BaseConfig):
    """
    Per-instrument settings controlling the reduction of DL3 data into a
    1D dataset.
    """

    # Label for this dataset section.
    name: str = "dataset-name"
    # Geometry settings; handed to generate_geom and get_dataset_reference.
    geom: GeomConfig = GeomConfig()
    # Observation selection settings; handed to get_filtered_observations.
    observation: ObservationsConfig = ObservationsConfig()
    # Background settings; its ``exclusion`` entry feeds the exclusion mask
    # and the section as a whole is handed to get_bkg_maker.
    background: BackgroundConfig = BackgroundConfig()
    # Safe mask settings; handed to get_safe_mask_maker and also read when
    # a "custom-mask" method is requested.
    safe_mask: SafeMaskConfig = SafeMaskConfig()
    # Sky position of the ON region; handed to get_source_position.
    on_region: SkyPositionConfig = SkyPositionConfig()
    # Enabled by default; presumably consumed by get_dataset_maker, which
    # receives this whole section -- confirm against that helper.
    containment_correction: bool = True
    # Empty by default; presumably consumed by get_dataset_maker as well.
    map_selection: list[MapSelectionEnum] = []
61
62
63
class Dataset1DBaseConfig(BaseConfig):
    """
    Base information of the 1D DL3 dataset for a single instrument: where the
    input comes from and how it is to be reduced.
    """

    # Instrument label; also stored in the metadata of the final datasets.
    name: str = "Instrument-name"
    # DL3 input sections; the dataset generation reads ``input_dir`` of the
    # first entry.
    input_dl3: list[InputDL3Config] = [InputDL3Config()]
    # When False the dataset is generated from DL3 data; when True it is
    # obtained through DL4Files.get_dl4_dataset instead.
    input_dl4: bool = False
    # Reduction settings used when generating the dataset from DL3 data.
    dataset_info: Dataset1DInfoConfig = Dataset1DInfoConfig()
    # DL4 settings used to construct the DL4Files helper.
    dl4_dataset_info: DL4BaseConfig = DL4BaseConfig()
73
74
75
class Dataset1DConfig(BaseConfig):
    """Top-level config section collecting the 1D DL3 dataset information of all instruments."""

    # Kind of reduction; defaults to the spectrum reduction type.
    type: ReductionTypeEnum = ReductionTypeEnum.spectrum
    # One entry per instrument; the analysis step iterates over this list.
    instruments: list[Dataset1DBaseConfig] = [Dataset1DBaseConfig()]
80
81
82
# The main Analysis Step
83
class Datasets1DAnalysisStep(AnalysisStepBase):
    """
    From the given config information, prepare the full list of 1D datasets,
    iterating over all the Instruments' information by running the
    Dataset1DGeneration function.
    """

    tag = "datasets-1d"

    def _run(self):
        """
        Build the 1D datasets of every configured instrument.

        For each instrument entry the dataset is either generated from DL3
        data (``input_dl4`` is False) or read via ``DL4Files``
        (``input_dl4`` is True), and then merged into the final ``Datasets``
        object by ``update_final_1d_datasets``.

        Returns
        -------
        tuple
            ``(datasets_1d_final, None, instrument_spectral_info)`` where
            ``instrument_spectral_info`` is a dict with the keys ``name``,
            ``spectral_energy_ranges``, ``en_bins`` and ``free_params``.
        """
        instruments_list = self.config.dataset1d.instruments
        self.log.info("%d number of 1D Datasets given", len(instruments_list))

        datasets_1d_final = Datasets()
        instrument_spectral_info = {"name": [], "spectral_energy_ranges": []}

        # Running total of the reconstructed energy bins used
        en_bins = 0

        # Iterate directly over the instrument entries; no index is needed.
        for config_1d_dataset in instruments_list:
            instrument_spectral_info["name"].append(config_1d_dataset.name)
            dl4_files = DL4Files(config_1d_dataset.dl4_dataset_info, self.log)

            if not config_1d_dataset.input_dl4:
                generate_1d_dataset = Dataset1DGeneration(self.log, config_1d_dataset, self.config)
                dataset = generate_1d_dataset.run()
            else:
                dataset = dl4_files.get_dl4_dataset(config_1d_dataset.dataset_info.observation)

            energy_bin_edges = dl4_files.get_spectral_energies()
            instrument_spectral_info["spectral_energy_ranges"].append(energy_bin_edges)

            datasets_1d_final, en_bins = update_final_1d_datasets(
                datasets_1d_final, dataset, config_1d_dataset.name, en_bins, self.config.general.stacked_dataset
            )

        instrument_spectral_info["en_bins"] = en_bins

        # No linked model parameters or other free model parameters taken here
        instrument_spectral_info["free_params"] = 0

        return (
            datasets_1d_final,
            None,
            instrument_spectral_info,
        )
131
132
133
def update_final_1d_datasets(datasets_1d_final, dataset, config_1d_dataset_name, en_bins, stacked_dataset=False):
    """
    Append the given instrument's dataset(s) to the final 1D datasets.

    Every appended dataset gets the instrument name written into its optional
    metadata and the Asgardpy version appended to its creator string, and the
    running count of reconstructed energy bins is updated for it.

    When ``stacked_dataset`` is true, ``dataset`` is first reduced with
    ``stack_reduce`` and the single result is appended; otherwise each element
    of ``dataset`` is appended individually.

    Returns the updated ``(datasets_1d_final, en_bins)`` pair.
    """
    # Normalise both modes to "an iterable of datasets to append".
    if stacked_dataset:
        to_append = (dataset.stack_reduce(name=config_1d_dataset_name),)
    else:
        to_append = dataset

    creator_suffix = f", Asgardpy {__public_version__}"

    for entry in to_append:
        entry._meta.optional = {"instrument": config_1d_dataset_name}
        entry._meta.creation.creator += creator_suffix
        en_bins = get_reco_energy_bins(entry, en_bins)
        datasets_1d_final.append(entry)

    return datasets_1d_final, en_bins
157
158
159
class Dataset1DGeneration:
160
    """
161
    Class for 1D dataset creation based on the config or AsgardpyConfig
162
    information provided on the 1D dataset and the target source.
163
164
    Runs the following steps:
165
166
    1. Read the DL3 files of 1D datasets into DataStore object.
167
168
    2. Perform any Observation selection, based on Observation IDs or time intervals.
169
170
    3. Create the base dataset reference, including the main counts geometry.
171
172
    4. Prepare standard data reduction makers using the parameters passed in the config.
173
174
    5. Generate the final dataset.
175
    """
176
177
    def __init__(self, log, config_1d_dataset, config_full):
178
        self.config_1d_dataset_io = config_1d_dataset.input_dl3
179
        self.log = log
180
        self.config_1d_dataset_info = config_1d_dataset.dataset_info
181
        self.config_target = config_full.target
182
        self.n_jobs = config_full.general.n_jobs
183
        self.parallel_backend = config_full.general.parallel_backend
184
        self.exclusion_regions = []
185
        self.datasets = Datasets()
186
187
    def run(self):
188
        """
189
        Main function to run the creation of 1D dataset.
190
        """
191
        # Applying all provided filters to get the Observations object
192
        observations = get_filtered_observations(
193
            dl3_path=self.config_1d_dataset_io[0].input_dir,
194
            obs_config=self.config_1d_dataset_info.observation,
195
            log=self.log,
196
        )
197
        # Get dict information of the ON region, with its SkyCoord position and angular radius
198
        center_pos = get_source_position(target_region=self.config_1d_dataset_info.on_region)
199
200
        # Create the main counts geometry
201
        geom = generate_geom(tag="1d", geom_config=self.config_1d_dataset_info.geom, center_pos=center_pos)
202
203
        # Get all the Dataset reduction makers
204
        dataset_reference = get_dataset_reference(
205
            tag="1d", geom=geom, geom_config=self.config_1d_dataset_info.geom
206
        )
207
208
        dataset_maker = get_dataset_maker(
209
            tag="1d",
210
            dataset_config=self.config_1d_dataset_info,
211
        )
212
213
        safe_maker = get_safe_mask_maker(safe_config=self.config_1d_dataset_info.safe_mask)
214
215
        excluded_geom = generate_geom(
216
            tag="1d-ex", geom_config=self.config_1d_dataset_info.geom, center_pos=center_pos
217
        )
218
        exclusion_mask = get_exclusion_region_mask(
219
            exclusion_params=self.config_1d_dataset_info.background.exclusion,
220
            exclusion_regions=self.exclusion_regions,
221
            excluded_geom=excluded_geom,
222
            config_target=self.config_target,
223
            geom_config=self.config_1d_dataset_info.geom,
224
            log=self.log,
225
        )
226
227
        bkg_maker = get_bkg_maker(
228
            bkg_config=self.config_1d_dataset_info.background,
229
            exclusion_mask=exclusion_mask,
230
        )
231
232
        # Produce the final Dataset
233
        self.datasets = generate_dl4_dataset(
234
            tag="1d",
235
            observations=observations,
236
            dataset_reference=dataset_reference,
237
            dataset_maker=dataset_maker,
238
            bkg_maker=bkg_maker,
239
            safe_maker=safe_maker,
240
            n_jobs=self.n_jobs,
241
            parallel_backend=self.parallel_backend,
242
        )
243
        self.update_dataset(observations)
244
245
        return self.datasets
246
247
    def update_dataset(self, observations):
248
        """
249
        Update the datasets generated by DatasetsMaker with names as per the
250
        Observation ID and if a custom safe energy mask is provided in the
251
        config, apply it to each dataset accordingly.
252
        """
253
        safe_cfg = self.config_1d_dataset_info.safe_mask
254
        pars = safe_cfg.parameters
255
256
        for data, obs in zip(self.datasets, observations, strict=True):
257
            # Rename the datasets using the appropriate Obs ID
258
            data._name = str(obs.obs_id)
259
260
            # Use custom safe energy mask
261
            if "custom-mask" in safe_cfg.methods:
262
                data.mask_safe = data.counts.geom.energy_mask(
263
                    energy_min=u.Quantity(pars["min"]), energy_max=u.Quantity(pars["max"]), round_to_edge=True
264
                )
265