gammapy.analysis.core (rating: F)

Complexity

Total Complexity: 92

Size/Duplication

Total Lines: 594
Duplicated Lines: 0 %

Importance

Changes: 0

Metric  Value
eloc (effective lines of code)  394
dl (duplicated lines)  0
loc (lines of code)  594
rs  2
c  0
b  0
f  0
wmc (weighted method count)  92

27 Methods

Rating   Name   Duplication   Size   Complexity  
A Analysis.__init__() 0 9 1
A Analysis.config() 0 4 3
B Analysis._make_obs_table_selection() 0 35 6
A Analysis.models() 0 5 1
A Analysis._set_data_store() 0 11 3
D Analysis.set_models() 0 38 12
A Analysis._create_dataset_maker() 0 19 4
A Analysis._make_energy_axis() 0 15 5
A Analysis._create_region_geometry() 0 9 1
B Analysis._create_background_maker() 0 34 6
A Analysis.write_models() 0 13 2
A Analysis._create_geometry() 0 15 3
A Analysis.run_fit() 0 17 4
A Analysis._create_wcs_geometry() 0 25 3
A Analysis._create_reference_dataset() 0 18 4
A Analysis.update_config() 0 2 1
A Analysis.get_light_curve() 0 34 3
A Analysis.get_flux_points() 0 25 2
A Analysis.read_datasets() 0 15 3
A Analysis.get_excess_map() 0 23 4
A Analysis.read_models() 0 15 1
B Analysis._spectrum_extraction() 0 29 5
A Analysis._map_making() 0 25 1
A Analysis.get_datasets() 0 10 4
A Analysis.write_datasets() 0 25 2
A Analysis._create_safe_mask_maker() 0 7 1
A Analysis.get_observations() 0 26 4

How to fix: Complexity

Complex classes like gammapy.analysis.core often do a lot of different things. To break such a class down, we need to identify a cohesive component within it. A common approach to finding such a component is to look for fields and methods that share the same prefixes or suffixes.

Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a subclass, Extract Subclass is also a candidate, and is often faster; a sketch follows below.
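
A hypothetical sketch of that refactoring applied here, pulling the "_create_*" maker-factory methods (which share a prefix) into a dedicated helper. The name MakerFactory is illustrative, not gammapy API, and the bodies are simplified mirrors of the methods in the listing below:

    from gammapy.makers import MapDatasetMaker, SafeMaskMaker, SpectrumDatasetMaker

    class MakerFactory:
        """Builds the data-reduction makers from the datasets settings."""

        def __init__(self, datasets_settings):
            self.settings = datasets_settings

        def dataset_maker(self):
            # mirrors Analysis._create_dataset_maker() below
            if self.settings.type == "3d":
                return MapDatasetMaker(selection=self.settings.map_selection)
            return SpectrumDatasetMaker(selection=["counts", "exposure", "edisp"])

        def safe_mask_maker(self):
            # mirrors Analysis._create_safe_mask_maker() below
            safe_mask = self.settings.safe_mask
            return SafeMaskMaker(methods=safe_mask.methods, **safe_mask.parameters)

Analysis would then delegate, e.g. maker = MakerFactory(self.config.datasets).dataset_maker().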

# Licensed under a 3-clause BSD style license - see LICENSE.rst
"""Session class driving the high level interface API"""
import logging
from astropy.coordinates import SkyCoord
from astropy.table import Table
from regions import CircleSkyRegion
from gammapy.analysis.config import AnalysisConfig
from gammapy.data import DataStore
from gammapy.datasets import Datasets, FluxPointsDataset, MapDataset, SpectrumDataset
from gammapy.estimators import (
    ExcessMapEstimator,
    FluxPointsEstimator,
    LightCurveEstimator,
)
from gammapy.makers import (
    DatasetsMaker,
    FoVBackgroundMaker,
    MapDatasetMaker,
    ReflectedRegionsBackgroundMaker,
    RingBackgroundMaker,
    SafeMaskMaker,
    SpectrumDatasetMaker,
)
from gammapy.maps import Map, MapAxis, RegionGeom, WcsGeom
from gammapy.modeling import Fit
from gammapy.modeling.models import DatasetModels, FoVBackgroundModel, Models
from gammapy.utils.pbar import progress_bar
from gammapy.utils.scripts import make_path

__all__ = ["Analysis"]

log = logging.getLogger(__name__)


class Analysis:
    """Config-driven high level analysis interface.

    By default it is initialized with the configuration parameters and values
    declared in an internal high level interface model. The user can also
    provide configuration parameters as a nested dictionary at instantiation;
    in that case these parameters overwrite the corresponding default values
    in the configuration.

    Parameters
    ----------
    config : dict or `AnalysisConfig`
        Configuration options following the `AnalysisConfig` schema.
    """
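
    # Instantiation sketch (hedged; "config.yaml" is an illustrative file
    # name). A nested dict and an AnalysisConfig object are both accepted:
    #
    #     analysis = Analysis({"general": {"log": {"level": "info"}}})
    #     analysis = Analysis(AnalysisConfig.read("config.yaml"))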

    def __init__(self, config):
        self.config = config
        self.config.set_logging()
        self.datastore = None
        self.observations = None
        self.datasets = None
        self.fit = Fit()
        self.fit_result = None
        self.flux_points = None

    @property
    def models(self):
        if not self.datasets:
            raise RuntimeError("No datasets defined. Impossible to access models.")
        return self.datasets.models

    @models.setter
    def models(self, models):
        self.set_models(models, extend=False)

    @property
    def config(self):
        """Analysis configuration (`AnalysisConfig`)."""
        return self._config

    @config.setter
    def config(self, value):
        if isinstance(value, dict):
            self._config = AnalysisConfig(**value)
        elif isinstance(value, AnalysisConfig):
            self._config = value
        else:
            raise TypeError("config must be dict or AnalysisConfig.")

    def _set_data_store(self):
        """Set the datastore on the Analysis object."""
        path = make_path(self.config.observations.datastore)
        if path.is_file():
            log.debug(f"Setting datastore from file: {path}")
            self.datastore = DataStore.from_file(path)
        elif path.is_dir():
            log.debug(f"Setting datastore from directory: {path}")
            self.datastore = DataStore.from_dir(path)
        else:
            raise FileNotFoundError(f"Datastore not found: {path}")

    def _make_obs_table_selection(self):
        """Return list of obs_ids after filtering on datastore observation table."""
        obs_settings = self.config.observations

        # Reject configs with list of obs_ids and obs_file set at the same time
        if len(obs_settings.obs_ids) and obs_settings.obs_file is not None:
            raise ValueError(
                "Values for both parameters obs_ids and obs_file are not accepted."
            )

        # First select input list of observations from obs_table
        if len(obs_settings.obs_ids):
            selected_obs_table = self.datastore.obs_table.select_obs_id(
                obs_settings.obs_ids
            )
        elif obs_settings.obs_file is not None:
            path = make_path(obs_settings.obs_file)
            ids = list(Table.read(path, format="ascii", data_start=0).columns[0])
            selected_obs_table = self.datastore.obs_table.select_obs_id(ids)
        else:
            selected_obs_table = self.datastore.obs_table

        # Apply cone selection
        if obs_settings.obs_cone.lon is not None:
            cone = dict(
                type="sky_circle",
                frame=obs_settings.obs_cone.frame,
                lon=obs_settings.obs_cone.lon,
                lat=obs_settings.obs_cone.lat,
                radius=obs_settings.obs_cone.radius,
                border="0 deg",
            )
            selected_obs_table = selected_obs_table.select_observations(cone)

        return selected_obs_table["OBS_ID"].tolist()

    def get_observations(self):
        """Fetch observations from the data store according to criteria
        defined in the configuration."""
        observations_settings = self.config.observations
        self._set_data_store()

        log.info("Fetching observations.")
        ids = self._make_obs_table_selection()
        required_irf = [_.value for _ in observations_settings.required_irf]
        self.observations = self.datastore.get_observations(
            ids, skip_missing=True, required_irf=required_irf
        )

        if observations_settings.obs_time.start is not None:
            start = observations_settings.obs_time.start
            stop = observations_settings.obs_time.stop
            if len(start.shape) == 0:
                time_intervals = [(start, stop)]
            else:
                time_intervals = [(tstart, tstop) for tstart, tstop in zip(start, stop)]
            self.observations = self.observations.select_time(time_intervals)

        log.info(f"Number of selected observations: {len(self.observations)}")

        for obs in self.observations:
            log.debug(obs)

    def get_datasets(self):
        """Produce reduced datasets."""
        datasets_settings = self.config.datasets
        if not self.observations or len(self.observations) == 0:
            raise RuntimeError("No observations have been selected.")

        if datasets_settings.type == "1d":
            self._spectrum_extraction()
        else:  # 3d
            self._map_making()

    def set_models(self, models, extend=True):
        """Set models on datasets.

        Adds a `FoVBackgroundModel` if not already present.

        Parameters
        ----------
        models : `~gammapy.modeling.models.Models` or str
            Models object or YAML models string.
        extend : bool
            Extend the existing models on the datasets or replace them.
        """
        if not self.datasets or len(self.datasets) == 0:
            raise RuntimeError("Missing datasets")

        log.info("Reading model.")
        if isinstance(models, str):
            models = Models.from_yaml(models)
        elif isinstance(models, Models):
            pass
        elif isinstance(models, (DatasetModels, list)):
            models = Models(models)
        else:
            raise TypeError(f"Invalid type: {models!r}")

        if extend:
            models.extend(self.datasets.models)

        self.datasets.models = models

        bkg_models = []
        for dataset in self.datasets:
            if dataset.tag == "MapDataset" and dataset.background_model is None:
                bkg_models.append(FoVBackgroundModel(dataset_name=dataset.name))
        if bkg_models:
            models.extend(bkg_models)
            self.datasets.models = models

        log.info(models)
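
    # Usage sketch (hedged; "model.yaml" is an illustrative file name). A
    # Models object, a list of models, or a YAML string are all accepted:
    #
    #     yaml_str = open("model.yaml").read()
    #     analysis.set_models(yaml_str)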

    def read_models(self, path, extend=True):
        """Read models from YAML file.

        Parameters
        ----------
        path : str
            Path to the model file.
        extend : bool
            Extend the existing models on the datasets or replace them.
        """
        path = make_path(path)
        models = Models.read(path)
        self.set_models(models, extend=extend)
        log.info(f"Models loaded from {path}.")

    def write_models(self, overwrite=True, write_covariance=True):
        """Write models to YAML file.

        The file name is taken from the configuration file.
        """
        filename_models = self.config.general.models_file
        if filename_models is not None:
            self.models.write(
                filename_models, overwrite=overwrite, write_covariance=write_covariance
            )
            log.info(f"Models written to {filename_models}.")
        else:
            raise RuntimeError("Missing models_file in config.general")

    def read_datasets(self):
        """Read datasets from YAML file.

        File names are taken from the configuration file.
        """
        filename = self.config.general.datasets_file
        filename_models = self.config.general.models_file
        if filename is not None:
            self.datasets = Datasets.read(filename)
            log.info(f"Datasets loaded from {filename}.")
            if filename_models is not None:
                self.read_models(filename_models, extend=False)
        else:
            raise RuntimeError("Missing datasets_file in config.general")

    def write_datasets(self, overwrite=True, write_covariance=True):
        """Write datasets to YAML file.

        File names are taken from the configuration file.

        Parameters
        ----------
        overwrite : bool
            Overwrite existing datasets FITS files.
        write_covariance : bool
            Save the covariance or not.
        """
        filename = self.config.general.datasets_file
        filename_models = self.config.general.models_file
        if filename is not None:
            self.datasets.write(
                filename,
                filename_models,
                overwrite=overwrite,
                write_covariance=write_covariance,
            )
            log.info(f"Datasets stored to {filename}.")
            log.info(f"Models stored to {filename_models}.")
        else:
            raise RuntimeError("Missing datasets_file in config.general")

    def run_fit(self):
        """Fit the reduced datasets to the model."""
        if not self.models:
            raise RuntimeError("Missing models")

        fit_settings = self.config.fit
        for dataset in self.datasets:
            if fit_settings.fit_range:
                energy_min = fit_settings.fit_range.min
                energy_max = fit_settings.fit_range.max
                geom = dataset.counts.geom
                dataset.mask_fit = geom.energy_mask(energy_min, energy_max)

        log.info("Fitting datasets.")
        result = self.fit.run(datasets=self.datasets)
        self.fit_result = result
        log.info(self.fit_result)

    def get_flux_points(self):
        """Calculate flux points for a specific model component."""
        if not self.datasets:
            raise RuntimeError(
                "No datasets defined. Impossible to compute flux points."
            )

        fp_settings = self.config.flux_points
        log.info("Calculating flux points.")
        energy_edges = self._make_energy_axis(fp_settings.energy).edges
        flux_point_estimator = FluxPointsEstimator(
            energy_edges=energy_edges,
            source=fp_settings.source,
            fit=self.fit,
            **fp_settings.parameters,
        )

        fp = flux_point_estimator.run(datasets=self.datasets)

        self.flux_points = FluxPointsDataset(
            data=fp, models=self.models[fp_settings.source]
        )
        cols = ["e_ref", "dnde", "dnde_ul", "dnde_err", "sqrt_ts"]
        table = self.flux_points.data.to_table(sed_type="dnde")
        log.info("\n{}".format(table[cols]))

    def get_excess_map(self):
        """Calculate excess map with respect to the current model."""
        excess_settings = self.config.excess_map
        log.info("Computing excess maps.")

        # TODO: Here we could possibly stack the datasets if needed
        # or allow to compute the excess map for each dataset
        if len(self.datasets) > 1:
            raise ValueError("Datasets must be stacked to compute the excess map")

        if self.datasets[0].tag not in ["MapDataset", "MapDatasetOnOff"]:
            raise ValueError("Cannot compute excess map for 1D dataset")

        energy_edges = self._make_energy_axis(excess_settings.energy_edges)
        if energy_edges is not None:
            energy_edges = energy_edges.edges

        excess_map_estimator = ExcessMapEstimator(
            correlation_radius=excess_settings.correlation_radius,
            energy_edges=energy_edges,
            **excess_settings.parameters,
        )
        self.excess_map = excess_map_estimator.run(self.datasets[0])

    def get_light_curve(self):
        """Calculate light curve for a specific model component."""
        lc_settings = self.config.light_curve
        log.info("Computing light curve.")
        energy_edges = self._make_energy_axis(lc_settings.energy_edges).edges

        if (
            lc_settings.time_intervals.start is None
            or lc_settings.time_intervals.stop is None
        ):
            log.info(
                "Time intervals not defined. Extract light curve on datasets GTIs."
            )
            time_intervals = None
        else:
            time_intervals = [
                (t1, t2)
                for t1, t2 in zip(
                    lc_settings.time_intervals.start, lc_settings.time_intervals.stop
                )
            ]

        light_curve_estimator = LightCurveEstimator(
            time_intervals=time_intervals,
            energy_edges=energy_edges,
            source=lc_settings.source,
            fit=self.fit,
            **lc_settings.parameters,
        )
        lc = light_curve_estimator.run(datasets=self.datasets)
        self.light_curve = lc
        log.info(
            "\n{}".format(
                self.light_curve.to_table(format="lightcurve", sed_type="flux")
            )
        )

    def update_config(self, config):
        """Update the analysis configuration."""
        self.config = self.config.update(config=config)

    @staticmethod
    def _create_wcs_geometry(wcs_geom_settings, axes):
        """Create the WCS geometry."""
        geom_params = {}
        skydir_settings = wcs_geom_settings.skydir
        if skydir_settings.lon is not None:
            skydir = SkyCoord(
                skydir_settings.lon, skydir_settings.lat, frame=skydir_settings.frame
            )
            geom_params["skydir"] = skydir

        if skydir_settings.frame in ["icrs", "galactic"]:
            geom_params["frame"] = skydir_settings.frame
        else:
            raise ValueError(
                f"Incorrect skydir frame: expect 'icrs' or 'galactic'. Got {skydir_settings.frame}"
            )

        geom_params["axes"] = axes
        geom_params["binsz"] = wcs_geom_settings.binsize
        width = wcs_geom_settings.width.width.to("deg").value
        height = wcs_geom_settings.width.height.to("deg").value
        geom_params["width"] = (width, height)

        return WcsGeom.create(**geom_params)

    @staticmethod
    def _create_region_geometry(on_region_settings, axes):
        """Create the region geometry."""
        on_lon = on_region_settings.lon
        on_lat = on_region_settings.lat
        on_center = SkyCoord(on_lon, on_lat, frame=on_region_settings.frame)
        on_region = CircleSkyRegion(on_center, on_region_settings.radius)

        return RegionGeom.create(region=on_region, axes=axes)

    def _create_geometry(self):
        """Create the geometry."""
        log.debug("Creating geometry.")
        datasets_settings = self.config.datasets
        geom_settings = datasets_settings.geom
        axes = [self._make_energy_axis(geom_settings.axes.energy)]
        if datasets_settings.type == "3d":
            geom = self._create_wcs_geometry(geom_settings.wcs, axes)
        elif datasets_settings.type == "1d":
            geom = self._create_region_geometry(datasets_settings.on_region, axes)
        else:
            raise ValueError(
                f"Incorrect dataset type. Expect '1d' or '3d'. Got {datasets_settings.type}."
            )
        return geom

    def _create_reference_dataset(self, name=None):
        """Create the reference dataset for the current analysis."""
        log.debug("Creating target Dataset.")
        geom = self._create_geometry()

        geom_settings = self.config.datasets.geom
        geom_irf = dict(energy_axis_true=None, binsz_irf=None)
        if geom_settings.axes.energy_true.min is not None:
            geom_irf["energy_axis_true"] = self._make_energy_axis(
                geom_settings.axes.energy_true, name="energy_true"
            )
        if geom_settings.wcs.binsize_irf is not None:
            geom_irf["binsz_irf"] = geom_settings.wcs.binsize_irf.to("deg").value

        if self.config.datasets.type == "1d":
            return SpectrumDataset.create(geom, name=name, **geom_irf)
        else:
            return MapDataset.create(geom, name=name, **geom_irf)

    def _create_dataset_maker(self):
        """Create the Dataset Maker."""
        log.debug("Creating the target Dataset Maker.")

        datasets_settings = self.config.datasets
        if datasets_settings.type == "3d":
            maker = MapDatasetMaker(selection=datasets_settings.map_selection)
        elif datasets_settings.type == "1d":
            maker_config = {}
            if datasets_settings.containment_correction:
                maker_config["containment_correction"] = (
                    datasets_settings.containment_correction
                )

            maker_config["selection"] = ["counts", "exposure", "edisp"]

            maker = SpectrumDatasetMaker(**maker_config)
        else:
            # Guard: without this branch `maker` could be unbound for
            # unexpected dataset types (flagged by the static analyzer).
            raise ValueError(
                f"Incorrect dataset type. Expect '1d' or '3d'. Got {datasets_settings.type}."
            )

        return maker

    def _create_safe_mask_maker(self):
        """Create the SafeMaskMaker."""
        log.debug("Creating the mask_safe Maker.")

        safe_mask_selection = self.config.datasets.safe_mask.methods
        safe_mask_settings = self.config.datasets.safe_mask.parameters
        return SafeMaskMaker(methods=safe_mask_selection, **safe_mask_settings)

    def _create_background_maker(self):
        """Create the Background maker."""
        log.info("Creating the background Maker.")

        datasets_settings = self.config.datasets
        bkg_maker_config = {}
        if datasets_settings.background.exclusion:
            path = make_path(datasets_settings.background.exclusion)
            exclusion_mask = Map.read(path)
            exclusion_mask.data = exclusion_mask.data.astype(bool)
            bkg_maker_config["exclusion_mask"] = exclusion_mask
        bkg_maker_config.update(datasets_settings.background.parameters)

        bkg_method = datasets_settings.background.method

        bkg_maker = None
        if bkg_method == "fov_background":
            log.debug(f"Creating FoVBackgroundMaker with arguments {bkg_maker_config}")
            bkg_maker = FoVBackgroundMaker(**bkg_maker_config)
        elif bkg_method == "ring":
            bkg_maker = RingBackgroundMaker(**bkg_maker_config)
            log.debug(f"Creating RingBackgroundMaker with arguments {bkg_maker_config}")
            if datasets_settings.geom.axes.energy.nbins > 1:
                raise ValueError(
                    "You need to define a single-bin energy geometry for your dataset."
                )
        elif bkg_method == "reflected":
            bkg_maker = ReflectedRegionsBackgroundMaker(**bkg_maker_config)
            log.debug(
                f"Creating ReflectedRegionsBackgroundMaker with arguments {bkg_maker_config}"
            )
        else:
            log.warning("No background maker set. Check configuration.")
        return bkg_maker
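
    # Configuration sketch (YAML; values illustrative). The field names
    # mirror the settings read above (datasets.background.method,
    # datasets.background.exclusion):
    #
    #     datasets:
    #         background:
    #             method: reflected
    #             exclusion: exclusion_mask.fits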

    def _map_making(self):
        """Make maps and datasets for 3d analysis."""
        datasets_settings = self.config.datasets
        offset_max = datasets_settings.geom.selection.offset_max

        log.info("Creating reference dataset and makers.")
        stacked = self._create_reference_dataset(name="stacked")

        maker = self._create_dataset_maker()
        maker_safe_mask = self._create_safe_mask_maker()
        bkg_maker = self._create_background_maker()

        makers = [maker, maker_safe_mask, bkg_maker]
        makers = [maker for maker in makers if maker is not None]

        log.info("Start the data reduction loop.")

        datasets_maker = DatasetsMaker(
            makers,
            stack_datasets=datasets_settings.stack,
            n_jobs=self.config.general.n_jobs,
            cutout_mode="trim",
            cutout_width=2 * offset_max,
        )
        self.datasets = datasets_maker.run(stacked, self.observations)
        # TODO: move progress bar to DatasetsMaker but how with multiprocessing ?

    def _spectrum_extraction(self):
        """Run all steps for the spectrum extraction."""
        log.info("Reducing spectrum datasets.")
        datasets_settings = self.config.datasets
        dataset_maker = self._create_dataset_maker()
        safe_mask_maker = self._create_safe_mask_maker()
        bkg_maker = self._create_background_maker()

        reference = self._create_reference_dataset()

        datasets = []
        for obs in progress_bar(self.observations, desc="Observations"):
            log.debug(f"Processing observation {obs.obs_id}")
            dataset = dataset_maker.run(reference.copy(), obs)
            if bkg_maker is not None:
                dataset = bkg_maker.run(dataset, obs)
                if dataset.counts_off is None:
                    log.debug(
                        f"No OFF region found for observation {obs.obs_id}. Discarding."
                    )
                    continue
            dataset = safe_mask_maker.run(dataset, obs)
            log.debug(dataset)
            datasets.append(dataset)
        self.datasets = Datasets(datasets)

        if datasets_settings.stack:
            stacked = self.datasets.stack_reduce(name="stacked")
            self.datasets = Datasets([stacked])

    @staticmethod
    def _make_energy_axis(axis, name="energy"):
        if axis.min is None or axis.max is None:
            return None
        elif axis.nbins is None or axis.nbins < 1:
            return None
        else:
            return MapAxis.from_bounds(
                name=name,
                lo_bnd=axis.min.value,
                hi_bnd=axis.max.to_value(axis.min.unit),
                nbin=axis.nbins,
                unit=axis.min.unit,
                interp="log",
                node_type="edges",
            )
595