Passed
Pull Request — main (#204)
by Chaitanya
01:26
created

asgardpy.config.operations.get_model_template()   A

Complexity

Conditions 3

Size

Total Lines 12
Code Lines 7

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 3
eloc 7
nop 1
dl 0
loc 12
rs 10
c 0
b 0
f 0
1
"""
2
Main AsgardpyConfig Operations Module
3
"""
4
5
import logging
6
from collections.abc import Mapping
7
from pathlib import Path
8
9
import numpy as np
10
from gammapy.modeling.models import Models, SkyModel
11
12
# Public API of this module.
__all__ = [
    "all_model_templates",
    "compound_model_dict_converstion",
    "get_model_template",
    "recursive_merge_dicts",
    "deep_update",
]

# Absolute directory that contains this module; bundled model template
# files are looked up relative to it.
CONFIG_PATH = Path(__file__).resolve().parents[0]

# Logger named after this module.
log = logging.getLogger(name=__name__)
23
24
25
def all_model_templates():
    """
    Collect all Template Models provided in Asgardpy, and their short tag names.

    The templates are the files matching ``model_templates/model_template*yaml``
    under ``CONFIG_PATH``. The tag of a file is the part of its name after the
    last underscore, cut at the first following dot, e.g.
    ``model_template_<tag>.yaml`` -> ``"<tag>"``.

    Returns
    -------
    all_tags : numpy.ndarray
        Tag names, in the same order as ``template_files``.
    template_files : list of pathlib.Path
        Paths of the template files, sorted.
    """
    # sorted() already returns a new list; no need to materialize the glob first.
    template_files = sorted(CONFIG_PATH.glob("model_templates/model_template*yaml"))

    # "model_template_<tag>.yaml" -> "<tag>"
    all_tags = np.array([file.name.split("_")[-1].split(".")[0] for file in template_files])

    return all_tags, template_files
37
38
39
def get_model_template(spec_model_tag):
    """
    Get the path of the Asgardpy template model yaml file for a given tag.

    This only looks up the file; it does not read it. The returned path can
    then be read to create an AsgardpyConfig object.

    Parameters
    ----------
    spec_model_tag : str
        Short tag name of the template, as returned by ``all_model_templates``
        (the ``<tag>`` in ``model_template_<tag>.yaml``).

    Returns
    -------
    new_model_file : pathlib.Path or None
        Path of the matching template file, or None if no template has this
        tag. If several files share the same tag, the last one in sorted
        order is returned.
    """
    all_tags, template_files = all_model_templates()
    new_model_file = None

    # Scan every template without breaking early, so the last match wins.
    for file, tag in zip(template_files, all_tags, strict=True):
        if spec_model_tag == tag:
            new_model_file = file

    return new_model_file
51
52
53
def check_gammapy_model(gammapy_model):
    """
    For a given object, try to read it as a Gammapy Models object.

    Parameters
    ----------
    gammapy_model : Models, SkyModel or other
        An existing Gammapy ``Models``/``SkyModel`` object, which is wrapped
        in ``Models``. Anything else is passed to ``Models.read`` (presumably a
        path to a model yaml file - confirm against callers).

    Returns
    -------
    models_gpy : Models
        The Gammapy Models object.

    Raises
    ------
    TypeError
        If ``Models.read`` fails with a KeyError.
    """
    if isinstance(gammapy_model, Models | SkyModel):
        return Models(gammapy_model)

    try:
        return Models.read(gammapy_model)
    except KeyError as err:
        # Format the message eagerly: unlike logging calls, exceptions do not
        # interpolate "%s" arguments. Chain the caught instance (not the
        # KeyError class) so the real cause is kept in the traceback.
        raise TypeError(f"{gammapy_model} File cannot be read by Gammapy Models") from err
66
67
68
def recursive_merge_lists(final_config_key, extra_config_key, value):
69
    """
70
    Recursively merge from lists of dicts. Distinct function as an auxiliary for
71
    the recursive_merge_dicts function.
72
    """
73
    new_config = []
74
75
    for key_, value_ in zip(final_config_key, value, strict=False):
76
        key_ = recursive_merge_dicts(key_ or {}, value_)
77
        new_config.append(key_)
78
79
    # For example moving from a smaller list of model parameters to a
80
    # longer list.
81
    if len(final_config_key) < len(extra_config_key):
82
        for value_ in value[len(final_config_key) :]:
83
            new_config.append(value_)
84
    return new_config
85
86
87
def recursive_merge_dicts(base_config, extra_config):
88
    """
89
    Recursively merge two dictionaries.
90
    Entries in extra_config override entries in base_config. The built-in
91
    update function cannot be used for hierarchical dicts.
92
93
    Also for the case when there is a list of dicts involved, one has to be
94
    more careful. The extra_config may have longer list of dicts as compared
95
    with the base_config, in which case, the extra items are simply added to
96
    the merged final list.
97
98
    Combined here are 2 options from SO.
99
100
    See:
101
    http://stackoverflow.com/questions/3232943/update-value-of-a-nested-dictionary-of-varying-depth/3233356#3233356
102
    and also
103
    https://stackoverflow.com/questions/3232943/update-value-of-a-nested-dictionary-of-varying-depth/18394648#18394648
104
105
    Parameters
106
    ----------
107
    base_config : dict
108
        dictionary to be merged
109
    extra_config : dict
110
        dictionary to be merged
111
    Returns
112
    -------
113
    final_config : dict
114
        merged dict
115
    """
116
    final_config = base_config.copy()
117
118
    for key, value in extra_config.items():
119
        if key in final_config and isinstance(final_config[key], list):
120
            final_config[key] = recursive_merge_lists(final_config[key], extra_config[key], value)
121
        elif key in final_config and isinstance(final_config[key], dict):
122
            final_config[key] = recursive_merge_dicts(final_config.get(key) or {}, value)
123
        else:
124
            final_config[key] = value
125
126
    return final_config
127
128
129
def deep_update(d, u):
    """
    Recursively update a nested dictionary.

    Just like in Gammapy, taken from: https://stackoverflow.com/a/3233356/19802442

    ``d`` is updated in place with the entries of ``u``. Nested mappings in
    ``u`` are merged key by key, and every other value in ``u`` overwrites
    the one in ``d``.

    Parameters
    ----------
    d : dict
        Dictionary to update; it is modified in place.
    u : Mapping
        Dictionary holding the new values.

    Returns
    -------
    d : dict
        The same ``d``, updated.
    """
    for k, v in u.items():
        if isinstance(v, Mapping):
            existing = d.get(k)
            # If d has no mapping under k (missing, None or a scalar), start
            # from an empty dict: recursing into a non-mapping would fail
            # on item assignment.
            d[k] = deep_update(existing if isinstance(existing, Mapping) else {}, v)
        else:
            d[k] = v
    return d
141
142
143
def compound_model_dict_converstion(dict):
    """
    Convert a Gammapy CompoundSpectralModel, given as a dict, into the
    Asgardpy form.

    The intrinsic spectral model is read from ``"model1"`` and the EBL
    absorption model from ``"model2"``. The conversion happens in place:

    * the first two parameter values of ``"model2"`` are stored on it as
      ``"alpha_norm"`` and ``"redshift"`` (in that order), its
      ``"parameters"`` entry is removed, and it is stored as ``"ebl_abs"``;
    * ``"type"`` and ``"parameters"`` are copied up from ``"model1"``;
    * ``"model1"``, ``"model2"`` and ``"operator"`` are dropped.

    Parameters
    ----------
    dict : dict
        Dictionary form of the compound model; modified in place.

    Returns
    -------
    dict : dict
        The same dictionary, converted.
    """
    absorption = dict["model2"]
    absorption_params = absorption["parameters"]
    # Positional read: first parameter is alpha_norm, second is redshift.
    absorption["alpha_norm"] = absorption_params[0]["value"]
    absorption["redshift"] = absorption_params[1]["value"]
    absorption.pop("parameters", None)

    intrinsic = dict["model1"]
    dict["type"] = intrinsic["type"]
    dict["parameters"] = intrinsic["parameters"]
    dict["ebl_abs"] = absorption

    # Drop the compound-model bookkeeping keys.
    for stale_key in ("model1", "model2", "operator"):
        dict.pop(stale_key, None)

    return dict
162