Passed
Pull Request — main (#157)
by Chaitanya
01:53
created

asgardpy.data.target   B

Complexity

Total Complexity 46

Size/Duplication

Total Lines 564
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
eloc 257
dl 0
loc 564
rs 8.72
c 0
b 0
f 0
wmc 46

How to fix   Complexity   

Complexity

Complex modules like asgardpy.data.target often do a lot of different things. To break such a module down, we need to identify a cohesive component within it. A common approach to find such a component is to look for fields/methods that share the same prefixes or suffixes.

Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a subclass, Extract Subclass is also a candidate and is often faster.

1
"""
2
Classes containing the Target config parameters for the high-level interface and
3
also the functions involving Models generation and assignment to datasets.
4
"""
5
6
from enum import Enum
7
8
import astropy.units as u
9
import numpy as np
10
from gammapy.modeling import Parameter
11
from gammapy.modeling.models import (
12
    SPATIAL_MODEL_REGISTRY,
13
    SPECTRAL_MODEL_REGISTRY,
14
    DatasetModels,
15
    EBLAbsorptionNormSpectralModel,
16
    FoVBackgroundModel,
17
    Models,
18
    SkyModel,
19
    SpectralModel,
20
)
21
22
from asgardpy.base.base import AngleType, BaseConfig, FrameEnum, PathType
23
from asgardpy.base.geom import SkyPositionConfig
24
25
# Public API of this module. The config classes that make up ``Target``
# (ModelParams, ModelComponent, CatalogConfig) are exported alongside the
# other config sections so they can be used to build a Target programmatically.
__all__ = [
    "BrokenPowerLaw2SpectralModel",
    "CatalogConfig",
    "EBLAbsorptionModel",
    "ExpCutoffLogParabolaSpectralModel",
    "ModelComponent",
    "ModelParams",
    "ModelTypeEnum",
    "RoISelectionConfig",
    "SpatialModelConfig",
    "SpectralModelConfig",
    "Target",
    "apply_selection_mask_to_models",
    "config_to_dict",
    "read_models_from_asgardpy_config",
    "set_models",
]
39
40
41
# Basic components to define the Target Config and any Models Config
42
class ModelTypeEnum(str, Enum):
    """
    Config section enumerating the Gammapy model types a component can be.
    """

    # Regular source model built from spectral (and optional spatial) parts
    skymodel = "SkyModel"
    # Field-of-view background model, tied to a single dataset
    fovbkgmodel = "FoVBackgroundModel"
49
50
51
class EBLAbsorptionModel(BaseConfig):
    """
    Config section describing the EBL absorption that is multiplied onto a
    spectral model through an EBLAbsorptionNormSpectralModel.
    """

    # Path to a custom EBL model file; preferred over ``reference`` when it
    # points to an existing file.
    filename: PathType = PathType("None")
    # Name of a built-in EBL model
    reference: str = ""
    type: str = "EBLAbsorptionNormSpectralModel"
    # Redshift of the source
    redshift: float = 0.0
    # Value for the ``alpha_norm`` parameter of the EBL absorption model
    alpha_norm: float = 1.0
61
62
63
class ModelParams(BaseConfig):
    """
    Config section holding the attributes of a single Gammapy Parameter.
    """

    name: str = ""
    value: float = 1
    # Unit string of the parameter value
    unit: str = " "
    error: float = 0.1
    # Lower and upper bounds of the parameter
    min: float = 0.1
    max: float = 10
    # Whether the parameter is kept fixed
    frozen: bool = True
73
74
75
class SpectralModelConfig(BaseConfig):
    """
    Config section from which a SpectralModel object is built.
    """

    # Tag used to look up the model class in SPECTRAL_MODEL_REGISTRY
    type: str = ""
    # Parameters of the spectral model
    parameters: list[ModelParams] = [ModelParams()]
    # Optional EBL absorption applied on top of the spectral model
    ebl_abs: EBLAbsorptionModel = EBLAbsorptionModel()
83
84
85
class SpatialModelConfig(BaseConfig):
    """
    Config section from which a SpatialModel object is built.
    """

    # Tag used to look up the model class in SPATIAL_MODEL_REGISTRY;
    # an empty string means "no spatial model".
    type: str = ""
    # Sky frame of the spatial model coordinates
    frame: FrameEnum = FrameEnum.icrs
    # Parameters of the spatial model
    parameters: list[ModelParams] = [ModelParams()]
93
94
95
class ModelComponent(BaseConfig):
    """
    Config section for one model component (a SkyModel or a
    FoVBackgroundModel) built from spectral and spatial sub-sections.
    """

    name: str = ""
    # Which Gammapy model class this component becomes
    type: ModelTypeEnum = ModelTypeEnum.skymodel
    # Dataset names for this component; a FoVBackgroundModel uses the first one
    datasets_names: list[str] = [""]
    spectral: SpectralModelConfig = SpectralModelConfig()
    spatial: SpatialModelConfig = SpatialModelConfig()
103
104
105
class RoISelectionConfig(BaseConfig):
    """
    Config section for the Region-of-Interest selection applied to the FoV
    source models.
    """

    # Source models at or beyond this distance from the target get their
    # spectral parameters frozen; 0 deg disables the radius selection.
    roi_radius: AngleType = 0 * u.deg
    # Sources whose spectral amplitude is left free (used when roi_radius is 0)
    free_sources: list[str] = []
113
114
115
class CatalogConfig(BaseConfig):
    """
    Config section for taking source information from a Catalog.
    """

    # Name of the catalog
    name: str = ""
    selection_radius: AngleType = 0 * u.deg
    exclusion_radius: AngleType = 0 * u.deg
121
122
123
class Target(BaseConfig):
    """
    Config section gathering everything needed to create the Models for the
    target source and any additional model components.
    """

    # Name of the target source; also used as the name of its SkyModel
    source_name: str = ""
    sky_position: SkyPositionConfig = SkyPositionConfig()
    use_uniform_position: bool = True
    # presumably a path to a file with Models to read — TODO confirm usage
    models_file: PathType = PathType("None")
    # Names of datasets that should get a FoVBackgroundModel added
    datasets_with_fov_bkg_model: list[str] = []
    use_catalog: CatalogConfig = CatalogConfig()
    # Model components built when no Models are passed in directly
    components: list[ModelComponent] = [ModelComponent()]
    covariance: str = ""
    from_3d: bool = False
    # RoI selection applied to externally provided Models
    roi_selection: RoISelectionConfig = RoISelectionConfig()
136
137
138
class ExpCutoffLogParabolaSpectralModel(SpectralModel):
    r"""Spectral Exponential Cutoff Log Parabola model.

    Using a simple template from Gammapy.

    .. math::
        \phi(E) = \phi_0 \left( \frac{E}{E_0} \right) ^ {
          - \alpha_1 - \beta \ln{ \left( \frac{E}{E_0} \right) }} \cdot
          \exp \left( - (\lambda E)^{\alpha_2} \right)

    The logarithm in the exponent is the natural logarithm (``np.log``).

    Parameters
    ----------
    amplitude : `~astropy.units.Quantity`
        :math:`\phi_0`
    reference : `~astropy.units.Quantity`
        :math:`E_0`
    alpha_1 : `~astropy.units.Quantity`
        :math:`\alpha_1`
    beta : `~astropy.units.Quantity`
        :math:`\beta`
    lambda_ : `~astropy.units.Quantity`
        :math:`\lambda`
    alpha_2 : `~astropy.units.Quantity`
        :math:`\alpha_2`
    """

    tag = ["ExpCutoffLogParabolaSpectralModel", "ECLP"]

    amplitude = Parameter(
        "amplitude",
        "1e-12 cm-2 s-1 TeV-1",
        scale_method="scale10",
        interp="log",
        is_norm=True,
    )
    reference = Parameter("reference", "1 TeV", frozen=True)
    alpha_1 = Parameter("alpha_1", -2)
    alpha_2 = Parameter("alpha_2", 1, frozen=True)
    beta = Parameter("beta", 1)
    lambda_ = Parameter("lambda_", "0.1 TeV-1")

    @staticmethod
    def evaluate(energy, amplitude, reference, alpha_1, beta, lambda_, alpha_2):
        """Evaluate the model (static function).

        Computes ``amplitude * (E/E0)**(-alpha_1 - beta*ln(E/E0))
        * exp(-(lambda_*E)**alpha_2)``.
        """
        en_ref = energy / reference
        exponent = -alpha_1 - beta * np.log(en_ref)
        cutoff = np.exp(-np.power(energy * lambda_, alpha_2))

        return amplitude * np.power(en_ref, exponent) * cutoff
187
188
189
class BrokenPowerLaw2SpectralModel(SpectralModel):
    r"""Spectral broken power-law 2 model.

    A variant of the Broken Power Law in which the second spectral index is not
    a parameter of its own. Instead, the model is parametrised by the first
    index and by the *difference* between the indices on either side of the
    break energy. This encodes the assumption that the physical process shaping
    the spectrum above the break depends on the one shaping it below.

    For more information see :ref:`broken-powerlaw-spectral-model`.

    .. math::
        \phi(E) = \phi_0 \cdot \begin{cases}
                \left( \frac{E}{E_{break}} \right)^{-\Gamma_1} & \text{if } E < E_{break} \\
                \left( \frac{E}{E_{break}} \right)^{-(\Gamma_1 + \Delta\Gamma)} & \text{otherwise}
            \end{cases}

    Parameters
    ----------
    index1 : `~astropy.units.Quantity`
        :math:`\Gamma_1`
    index_diff : `~astropy.units.Quantity`
        :math:`\Delta\Gamma`
    amplitude : `~astropy.units.Quantity`
        :math:`\phi_0`
    ebreak : `~astropy.units.Quantity`
        :math:`E_{break}`

    See Also
    --------
    SmoothBrokenPowerLawSpectralModel
    """

    tag = ["BrokenPowerLaw2SpectralModel", "bpl2"]
    index1 = Parameter("index1", 2.0)
    index_diff = Parameter("index_diff", 1.0)
    amplitude = Parameter(
        name="amplitude",
        value="1e-12 cm-2 s-1 TeV-1",
        scale_method="scale10",
        interp="log",
        is_norm=True,
    )
    ebreak = Parameter("ebreak", "1 TeV")

    @staticmethod
    def evaluate(energy, index1, index_diff, amplitude, ebreak):
        """Evaluate the model (static function).

        ``energy`` is promoted to at least 1-D, so a scalar input yields a
        length-1 result.
        """
        energy = np.atleast_1d(energy)
        index2 = index1 + index_diff

        below_break = energy < ebreak
        above_break = ~below_break

        flux = amplitude * np.ones(energy.shape)
        flux[below_break] *= (energy[below_break] / ebreak) ** (-index1)
        flux[above_break] *= (energy[above_break] / ebreak) ** (-index2)

        return flux
246
247
248
# Register the custom spectral models (in this order) so that they can be
# looked up by tag, e.g. through ``SPECTRAL_MODEL_REGISTRY.get_cls``.
SPECTRAL_MODEL_REGISTRY.extend([ExpCutoffLogParabolaSpectralModel, BrokenPowerLaw2SpectralModel])
250
251
252
# Function for Models assignment
253
def extend_bkg_models(models, all_datasets, datasets_with_fov_bkg_model):
    """
    Add a FoVBackgroundModel for each selected Map-type dataset that lacks one.

    Parameters
    ----------
    models: `~gammapy.modeling.models.Models`
        Models object, extended in place.
    all_datasets: `gammapy.datasets.Datasets`
        Datasets to scan.
    datasets_with_fov_bkg_model: list of str
        Names of the datasets that should get a FoVBackgroundModel.

    Returns
    -------
    models: `~gammapy.modeling.models.Models`
        The same ``models`` object, extended with any new background models.
    """
    if not datasets_with_fov_bkg_model:
        return models

    # For extending a Background Model for a given 3D dataset name
    bkg_models = []
    for dataset in all_datasets:
        if dataset.name not in datasets_with_fov_bkg_model:
            continue
        # Only MapDataset / MapDatasetOnOff types are handled, and only when
        # no background model is associated with the dataset already.
        if "MapDataset" not in dataset.tag or dataset.background_model is not None:
            continue
        # FoVBackgroundModel derives its own name as "<dataset_name>-bkg", so it
        # must be given the bare dataset name. Skip it when a model with that
        # name is already present (e.g. a FoVBackgroundModel config component),
        # since Models requires unique names.
        if f"{dataset.name}-bkg" in models.names:
            continue
        bkg_models.append(FoVBackgroundModel(dataset_name=dataset.name))

    models.extend(bkg_models)

    return models
269
270
271
def update_models_datasets_names(models, source_name, datasets_name_list):
    """
    Assign ``datasets_names`` to the models before they are set on datasets.

    With several models, the target source model gets ``datasets_name_list``.
    Each model whose name contains "bkg" gets its name minus the last 4
    characters (the "-bkg" suffix) as its dataset name. With a single model,
    that model gets ``datasets_name_list``.

    Parameters
    ----------
    models: `~gammapy.modeling.models.Models`
        Models to update in place.
    source_name: str
        Name of the target source model.
    datasets_name_list: list of str
        Dataset names to assign to the target source model.

    Returns
    -------
    models: `~gammapy.modeling.models.Models`
        The same ``models`` object.
    """
    if len(models) > 1:
        models[source_name].datasets_names = datasets_name_list

        # Check if FoVBackgroundModel is provided, and remove the -bkg part of
        # its name to get the appropriate datasets name.
        bkg_model_names = [m_name for m_name in models.names if "bkg" in m_name]
        for bkg_name in bkg_model_names:
            models[bkg_name].datasets_names = [bkg_name[:-4]]
    else:
        # Set the names on the model itself. Assigning to the Models container
        # would only add an attribute there, not on the contained model.
        for model in models:
            model.datasets_names = datasets_name_list

    return models
288
289
290
def set_models(
    config_target,
    datasets,
    datasets_name_list=None,
    models=None,
):
    """
    Set models on given Datasets.

    Parameters
    ----------
    config_target: `AsgardpyConfig.target`
        AsgardpyConfig containing target information.
    datasets: `gammapy.datasets.Datasets`
        Datasets object
    datasets_name_list: list or None
        List of datasets_names to be used on the Models, before assigning them
        to the given datasets. If None, ``datasets.names`` is used.
    models : `~gammapy.modeling.models.DatasetModels`, list, `PathType` or None
        Models object, list of models, or path of a models file (read with
        ``Models.read``). If None or empty, the models are built from
        ``config_target.components``. Given models get the RoI selection of
        ``config_target.roi_selection`` applied.

    Returns
    -------
    datasets: `gammapy.datasets.Datasets`
        Datasets object with Models assigned.
    models: `~gammapy.modeling.models.Models`
        The Models object assigned to the datasets.

    Raises
    ------
    TypeError
        If ``models`` is of an unsupported type.
    ValueError
        If no models are given and the config has no named model component.
    """
    # Normalise the ``models`` argument into a Models object
    if isinstance(models, DatasetModels | list):
        models = Models(models)
    elif isinstance(models, PathType):
        models = Models.read(models)
    elif models is None:
        models = Models()
    else:
        raise TypeError(f"Invalid type: {type(models)}")

    if len(models) == 0:
        # Fall back to the model components described in the config. Guard
        # against an empty components list instead of raising an IndexError.
        components = config_target.components
        if components and components[0].name != "":
            models = read_models_from_asgardpy_config(config_target)
        else:
            raise ValueError("No input for Models provided for the Target source!")
    else:
        models = apply_selection_mask_to_models(
            list_sources=models,
            target_source=config_target.source_name,
            roi_radius=config_target.roi_selection.roi_radius,
            free_sources=config_target.roi_selection.free_sources,
        )

    models = extend_bkg_models(models, datasets, config_target.datasets_with_fov_bkg_model)

    if datasets_name_list is None:
        datasets_name_list = datasets.names

    models = update_models_datasets_names(models, config_target.source_name, datasets_name_list)

    datasets.models = models

    return datasets, models
349
350
351
def apply_models_mask_in_roi(list_sources_excluded, target_source, roi_radius, free_sources):
    """
    Freeze spectral parameters of source models according to an RoI selection.

    If ``roi_radius`` is non-zero, every model whose position is at or beyond
    ``roi_radius`` from the target source has all its spectral parameters
    frozen. Otherwise, if ``free_sources`` is non-empty, the spectral
    parameters of all models other than the target are frozen. The
    ``amplitude`` of the models named in ``free_sources`` is then unfrozen.

    Parameters
    ----------
    list_sources_excluded: `~gammapy.modeling.models.Models`
        Source models (diffuse/background models already separated out).
        Modified in place.
    target_source: str
        Name of the target source, used as the RoI center. If empty, the first
        model of ``list_sources_excluded`` is used.
    roi_radius: `astropy.units.Quantity` or `asgardpy.base.base.AngleType`
        RoI radius; a value of 0 disables the radius-based selection.
    free_sources: list of str
        Names of the models whose spectral amplitude is kept free.

    Returns
    -------
    list_sources_excluded: `~gammapy.modeling.models.Models`
        The same object, with spectral parameters frozen/unfrozen.
    """
    if len(list_sources_excluded) == 0:
        # Nothing to select (e.g. only diffuse/background models were given).
        return list_sources_excluded

    # Get the target source model, whose position is the center of the RoI
    if not target_source:
        target_model = list_sources_excluded[0]
        target_source = target_model.name
    else:
        target_model = list_sources_excluded[target_source]

    # Convert explicitly: a plain Quantity (e.g. the ``0 * u.deg`` default)
    # does not provide the ``.deg`` shortcut that Angle has.
    roi_radius_deg = roi_radius.to_value("deg")

    if roi_radius_deg != 0:
        target_source_pos = target_model.spatial_model.position
        for model_ in list_sources_excluded:
            separation = target_source_pos.separation(model_.spatial_model.position).deg
            if separation >= roi_radius_deg:
                model_.spectral_model.freeze()
    elif len(free_sources) > 0:
        for model_ in list_sources_excluded:
            # Freeze all spectral parameters for other models
            if model_.name != target_source:
                model_.spectral_model.freeze()
            # and now unfreeze the amplitude of selected models
            if model_.name in free_sources:
                model_.spectral_model.parameters["amplitude"].frozen = False

    return list_sources_excluded
379
380
381
def apply_selection_mask_to_models(
    list_sources, target_source=None, selection_mask=None, roi_radius=0 * u.deg, free_sources=None
):
    """
    For a given list of source models, with a given target source, apply various
    selection masks on the Region of Interest in the sky. This will lead to
    complete exclusion of models or freezing some or all spectral parameters.

    These selections exclude the diffuse background models in the given list.
    A model counts as diffuse/background when its name contains "diffuse" or
    "bkg". Such models are appended back, unchanged, after the selection.

    First priority is given if a distinct selection mask is provided, with a
    list of excluded regions to return only the source models within the selected
    ROI.

    Second priority is on creating a Circular ROI as per the given radius, and
    freeze all the spectral parameters of the models of the sources outside it.

    Third priority is when a list of free_sources is provided, then the
    spectral amplitude of models of those sources, if present in the given list
    of models, will be unfrozen, and hence, allowed to be variable for fitting.

    Parameters
    ----------
    list_sources: '~gammapy.modeling.models.Models'
        Models object containing a list of source models.
    target_source: 'str' or None
        Name of the target source, whose position is used as the center of ROI.
        If not given, the first source model is used.
    selection_mask: 'WcsNDMap' or None
        Map containing a boolean mask to apply to Models object.
    roi_radius: 'astropy.units.Quantity' or 'asgardpy.base.base.AngleType'
        Radius for a circular region around ROI (deg)
    free_sources: 'list' or None
        List of source names for which the spectral amplitude is to be kept free.

    Returns
    -------
    list_sources: '~gammapy.modeling.models.Models'
        Selected Models object.
    """
    if free_sources is None:
        free_sources = []

    # Separate the list of sources and diffuse background
    list_sources_excluded = []
    list_diffuse = []
    for model_ in list_sources:
        if "diffuse" in model_.name or "bkg" in model_.name:
            list_diffuse.append(model_)
        else:
            list_sources_excluded.append(model_)

    list_sources_excluded = Models(list_sources_excluded)

    # If a distinct selection mask is provided. Compare with None explicitly:
    # truthiness is not a reliable "was given" test for map/array masks.
    if selection_mask is not None:
        list_sources_excluded = list_sources_excluded.select_mask(selection_mask)

    list_sources_excluded = apply_models_mask_in_roi(
        list_sources_excluded, target_source, roi_radius, free_sources
    )

    # Add the diffuse background models back
    for diff_ in list_diffuse:
        list_sources_excluded.append(diff_)

    return list_sources_excluded
450
451
452
# Functions for Models generation
453
def add_ebl_model_from_config(spectral_model, model_config):
    """
    Multiply a spectral model by the EBL absorption model described in config.

    A custom EBL model is read from ``ebl_abs.filename`` when it is an existing
    file. In that case ``ebl_abs.reference`` is overwritten (for logging) with
    the file name, minus its ``.fits``/``.gz`` extensions and with "-" replaced
    by "_". Otherwise the built-in model ``ebl_abs.reference`` is used.

    Parameters
    ----------
    spectral_model: `~gammapy.modeling.models.SpectralModel`
        Spectral model to absorb.
    model_config: `ModelComponent`
        Model component config; its ``spectral.ebl_abs`` section is used and
        may be updated as described above.

    Returns
    -------
    spectral_model: `~gammapy.modeling.models.SpectralModel`
        The product of ``spectral_model`` and the EBL absorption model.
    """
    ebl_model = model_config.spectral.ebl_abs

    # First check for filename of a custom EBL model
    if ebl_model.filename.is_file():
        ebl_absorption = EBLAbsorptionNormSpectralModel.read(
            str(ebl_model.filename), redshift=ebl_model.redshift
        )
        # Update the reference name when using the custom EBL model for
        # logging. Strip the known extensions instead of a fixed number of
        # characters, so both ".fits.gz" and ".fits" files are handled.
        reference = ebl_model.filename.name
        for suffix in (".gz", ".fits"):
            reference = reference.removesuffix(suffix)
        ebl_model.reference = reference.replace("-", "_")
    else:
        ebl_absorption = EBLAbsorptionNormSpectralModel.read_builtin(
            ebl_model.reference, redshift=ebl_model.redshift
        )

    # Always apply the configured normalisation: a truthiness check would
    # silently ignore an explicit alpha_norm of 0.
    ebl_absorption.alpha_norm.value = ebl_model.alpha_norm

    return spectral_model * ebl_absorption
470
471
472
def read_models_from_asgardpy_config(config):
473
    """
474
    Reading Models information from AsgardpyConfig and return Models object.
475
    If a FoVBackgroundModel information is provided, it will also be added.
476
477
    Parameter
478
    ---------
479
    config: `AsgardpyConfig`
480
        Config section containing Target source information
481
482
    Returns
483
    -------
484
    models: `gammapy.modeling.models.Models`
485
        Full gammapy Models object.
486
    """
487
    models = Models()
488
489
    for model_config in config.components:
490
        # Spectral Model
491
        spectral_model = SPECTRAL_MODEL_REGISTRY.get_cls(model_config.spectral.type)().from_dict(
492
            {"spectral": config_to_dict(model_config.spectral)}
493
        )
494
        if model_config.spectral.ebl_abs.reference != "":
495
            spectral_model = add_ebl_model_from_config(spectral_model, model_config)
496
497
        spectral_model.name = config.source_name
498
499
        # Spatial model if provided
500
        if model_config.spatial.type != "":
501
            spatial_model = SPATIAL_MODEL_REGISTRY.get_cls(model_config.spatial.type)().from_dict(
502
                {"spatial": config_to_dict(model_config.spatial)}
503
            )
504
        else:
505
            spatial_model = None
506
507
        match model_config.type:
508
            case "SkyModel":
509
                models.append(
510
                    SkyModel(
511
                        spectral_model=spectral_model,
512
                        spatial_model=spatial_model,
513
                        name=config.source_name,
514
                    )
515
                )
516
517
            # FoVBackgroundModel is the second component
518
            case "FoVBackgroundModel":
519
                models.append(
520
                    FoVBackgroundModel(
521
                        spectral_model=spectral_model,
522
                        spatial_model=spatial_model,
523
                        dataset_name=model_config.datasets_names[0],
524
                    )
525
                )
526
527
    return models
528
529
530
def config_to_dict(model_config):
    """
    Convert a Spectral/Spatial model config section into a dict.

    Probably an extra step and may be removed later.

    Parameters
    ----------
    model_config: `SpectralModelConfig` or `SpatialModelConfig`
        Config section with a ``type`` and a list of ``parameters``. A
        ``frame`` attribute, when present (spatial models), is copied as well.

    Returns
    -------
    model_dict: dict
        Dictionary with keys ``type``, ``parameters`` (one dict per parameter
        with keys name, value, unit, error, min, max, frozen) and, for spatial
        models, ``frame``.
    """
    param_keys = ("name", "value", "unit", "error", "min", "max", "frozen")

    model_dict = {
        "type": str(model_config.type),
        "parameters": [
            {key: getattr(par, key) for key in param_keys} for par in model_config.parameters
        ],
    }

    # For spatial model, include frame info
    if hasattr(model_config, "frame"):
        model_dict["frame"] = model_config.frame

    return model_dict
565