Passed
Pull Request — main (#178)
by Chaitanya
05:19
created

asgardpy.config.generator.AsgardpyConfig.write()   A

Complexity

Conditions 3

Size

Total Lines 8
Code Lines 5

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 3
eloc 5
nop 3
dl 0
loc 8
rs 10
c 0
b 0
f 0
1
"""
2
Main AsgardpyConfig Generator Module
3
"""
4
5
import json
6
import logging
7
from collections.abc import Mapping
8
from enum import Enum
9
from pathlib import Path
10
11
import numpy as np
12
import yaml
13
from gammapy.modeling.models import Models
14
from gammapy.utils.scripts import make_path, read_yaml
15
16
from asgardpy.analysis.step_base import AnalysisStepEnum
17
from asgardpy.base import BaseConfig, PathType
18
from asgardpy.data import (
19
    Dataset1DConfig,
20
    Dataset3DConfig,
21
    FitConfig,
22
    FluxPointsConfig,
23
    Target,
24
)
25
26
# Names exported by ``from asgardpy.config.generator import *``.
__all__ = [
    "all_model_templates",
    "AsgardpyConfig",
    "GeneralConfig",
    "gammapy_to_asgardpy_model_config",
    "get_model_template",
    "recursive_merge_dicts",
]

# Module-level logger.
log = logging.getLogger(__name__)

# Directory holding this module; the template model YAML files are searched
# for in its ``model_templates`` sub-folder.
CONFIG_PATH = Path(__file__).resolve().parent
38
39
40
# Other general config params
class LogConfig(BaseConfig):
    """
    Config section for main logging information.

    ``AsgardpyConfig.set_logging`` upper-cases ``level`` and then passes all of
    these fields, as keyword arguments, to `logging.basicConfig`.
    """

    # Logging level name (upper-cased by ``set_logging`` before use)
    level: str = "info"
    # Log file name; empty by default
    filename: str = ""
    # Mode used to open ``filename``
    filemode: str = "w"
    # Log record format string; empty by default
    format: str = ""
    # Date/time format string; empty by default
    datefmt: str = ""
49
50
51
class ParallelBackendEnum(str, Enum):
    """
    Allowed choices for the parallel processing backend.

    Used as the type of ``GeneralConfig.parallel_backend``; each member's value
    is the backend name string as written in the configuration.
    """

    # Member name -> backend name string
    multi = "multiprocessing"
    ray = "ray"
56
57
58
class GeneralConfig(BaseConfig):
    """
    Config section for general information for running AsgardpyAnalysis.

    Holds the logging settings, output location, parallelisation options and
    the list of analysis steps to run.
    """

    # Logging settings (see LogConfig)
    log: LogConfig = LogConfig()
    # Output directory.
    # NOTE(review): the default is the *string* "None"; presumably PathType
    # converts it into a usable value — confirm in asgardpy.base.
    outdir: PathType = "None"
    # Parallelisation: presumably the number of jobs, and the backend used
    n_jobs: int = 1
    parallel_backend: ParallelBackendEnum = ParallelBackendEnum.multi
    # Analysis steps to run; empty by default
    steps: list[AnalysisStepEnum] = []
    # Presumably whether existing outputs may be overwritten — TODO confirm
    overwrite: bool = True
    # Presumably whether datasets are stacked — TODO confirm
    stacked_dataset: bool = False
69
70
def all_model_templates(template_dir=None):
    """
    Collect all Template Models provided in Asgardpy, and their small tag names.

    Parameters
    ----------
    template_dir : str or `pathlib.Path`, optional
        Directory searched for ``model_template*yaml`` files. Defaults to the
        ``model_templates`` folder next to this module.

    Returns
    -------
    all_tags : `numpy.ndarray` of str
        Tag of each template file: the part of the file name after the last
        underscore and before the first following dot, e.g.
        ``model_template_bpl.yaml`` -> ``bpl``.
    template_files : list of `pathlib.Path`
        Sorted template file paths, in the same order as ``all_tags``.
    """
    if template_dir is None:
        template_dir = CONFIG_PATH / "model_templates"

    # sorted() already returns a list; no need to materialize the glob first.
    template_files = sorted(Path(template_dir).glob("model_template*yaml"))

    # dtype=str keeps the array a string array even when no template is found.
    all_tags = np.array(
        [file.name.split("_")[-1].split(".")[0] for file in template_files],
        dtype=str,
    )

    return all_tags, template_files
82
83
84
def get_model_template(spec_model_tag):
    """
    Find the template model YAML file registered under a given tag.

    The returned file can be read to create an AsgardpyConfig object.

    Parameters
    ----------
    spec_model_tag : str
        Short tag of the template, as listed by `all_model_templates`.

    Returns
    -------
    new_model_file : `pathlib.Path` or None
        Path of the matching template, or None when no template carries this
        tag. If several templates share the tag, the last one in the order
        given by `all_model_templates` is returned.
    """
    all_tags, template_files = all_model_templates()

    matching_files = [
        candidate
        for candidate, candidate_tag in zip(template_files, all_tags, strict=True)
        if spec_model_tag == candidate_tag
    ]

    return matching_files[-1] if matching_files else None
96
97
98
def recursive_merge_dicts(base_config, extra_config):
99
    """
100
    recursively merge two dictionaries.
101
    Entries in extra_config override entries in base_config. The built-in
102
    update function cannot be used for hierarchical dicts.
103
104
    Also for the case when there is a list of dicts involved, one has to be
105
    more careful. The extra_config may have longer list of dicts as compared
106
    with the base_config, in which case, the extra items are simply added to
107
    the merged final list.
108
109
    Combined here are 2 options from SO.
110
111
    See:
112
    http://stackoverflow.com/questions/3232943/update-value-of-a-nested-dictionary-of-varying-depth/3233356#3233356
113
    and also
114
    https://stackoverflow.com/questions/3232943/update-value-of-a-nested-dictionary-of-varying-depth/18394648#18394648
115
116
    Parameters
117
    ----------
118
    base_config : dict
119
        dictionary to be merged
120
    extra_config : dict
121
        dictionary to be merged
122
    Returns
123
    -------
124
    final_config : dict
125
        merged dict
126
    """
127
    final_config = base_config.copy()
128
129
    for key, value in extra_config.items():
130
        if key in final_config and isinstance(final_config[key], list):
131
            new_config = []
132
133
            for key_, value_ in zip(final_config[key], value, strict=False):
134
                key_ = recursive_merge_dicts(key_ or {}, value_)
135
                new_config.append(key_)
136
137
            # For example moving from a smaller list of model parameters to a
138
            # longer list.
139
            if len(final_config[key]) < len(extra_config[key]):
140
                for value_ in value[len(final_config[key]) :]:
141
                    new_config.append(value_)
142
            final_config[key] = new_config
143
144
        elif key in final_config and isinstance(final_config[key], dict):
145
            final_config[key] = recursive_merge_dicts(final_config.get(key) or {}, value)
146
        else:
147
            final_config[key] = value
148
149
    return final_config
150
151
152
def deep_update(d, u):
    """
    Recursively update a nested dictionary in place.

    Just like in Gammapy, taken from: https://stackoverflow.com/a/3233356/19802442

    Mapping values of ``u`` are merged into the corresponding entry of ``d``.
    If that entry is missing, None, or not a mapping, a fresh dict is used.
    Any other value of ``u`` overwrites the entry of ``d``.

    Parameters
    ----------
    d : dict
        Dictionary to update; it is modified in place.
    u : Mapping
        Dictionary with the new values.

    Returns
    -------
    d : dict
        The same (updated) object that was passed in.
    """
    for k, v in u.items():
        if isinstance(v, Mapping):
            existing = d.get(k)
            # Previously ``d.get(k, {})`` returned e.g. None here and the
            # recursive call crashed on ``None[k] = ...``.
            if not isinstance(existing, Mapping):
                existing = {}
            d[k] = deep_update(existing, v)
        else:
            d[k] = v
    return d
164
165
166
def gammapy_to_asgardpy_model_config(gammapy_model, asgardpy_config_file=None, recursive_merge=True):
    """
    Read the Gammapy Models YAML file and save it as AsgardpyConfig object.

    The ``target`` section of the AsgardpyConfig is merged with the model
    definitions read by Gammapy; all other sections are left untouched.

    Parameters
    ----------
    gammapy_model : str or `pathlib.Path`
        Models file readable by `gammapy.modeling.models.Models.read`.
    asgardpy_config_file : str, `pathlib.Path`, `AsgardpyConfig` or None
        Base configuration. A falsy value uses the default AsgardpyConfig, a
        path is read with `AsgardpyConfig.read`, and an AsgardpyConfig object
        is modified in place (and also returned).
    recursive_merge : bool
        If True, merge the models into the target section with
        `recursive_merge_dicts`; otherwise use `deep_update`.

    Return
    ------
    asgardpy_config: `asgardpy.config.generator.AsgardpyConfig` or None
        Updated AsgardpyConfig object, or None when reading ``gammapy_model``
        raises a KeyError (an error is logged).

    Raises
    ------
    TypeError
        If ``asgardpy_config_file`` is of an unsupported type.
    """
    try:
        models_gpy = Models.read(gammapy_model)
    except KeyError:
        log.error("%s File cannot be read by Gammapy Models", gammapy_model)
        return None

    if not asgardpy_config_file:
        asgardpy_config = AsgardpyConfig()  # Default object
    elif isinstance(asgardpy_config_file, (str, Path)):  # File path
        asgardpy_config = AsgardpyConfig.read(asgardpy_config_file)
    elif isinstance(asgardpy_config_file, AsgardpyConfig):
        # The caller's object is updated in place.
        asgardpy_config = asgardpy_config_file
    else:
        # Previously this fell through and raised UnboundLocalError below.
        raise TypeError(
            f"Invalid type for asgardpy_config_file: {type(asgardpy_config_file)}"
        )

    models_gpy_dict = models_gpy.to_dict()
    asgardpy_config_target_dict = asgardpy_config.model_dump()["target"]

    if recursive_merge:
        temp_target_dict = recursive_merge_dicts(asgardpy_config_target_dict, models_gpy_dict)
    else:
        # Use when there are nans present in the other config file, which are
        # the defaults in Gammapy, but NOT in Asgardpy.
        # E.g. test data Fermi-3fhl-crab model file
        temp_target_dict = deep_update(asgardpy_config_target_dict, models_gpy_dict)
    asgardpy_config.target = temp_target_dict

    return asgardpy_config
201
202
203
# Combine everything!
class AsgardpyConfig(BaseConfig):
    """
    Asgardpy analysis configuration, based on Gammapy Analysis Config.

    Collects the general settings, the target model definitions, the 3D and
    1D dataset settings, and the fit and flux-points parameters in a single
    object that can be read from and written to YAML.
    """

    general: GeneralConfig = GeneralConfig()

    target: Target = Target()

    dataset3d: Dataset3DConfig = Dataset3DConfig()
    dataset1d: Dataset1DConfig = Dataset1DConfig()

    fit_params: FitConfig = FitConfig()
    flux_points_params: FluxPointsConfig = FluxPointsConfig()

    def __str__(self):
        """
        Display settings in pretty YAML format.

        The class name is followed by the YAML dump, with each YAML line
        indented by 4 spaces.
        """
        info = self.__class__.__name__ + "\n\n\t"
        data = self.to_yaml()
        data = data.replace("\n", "\n\t")
        info += data
        return info.expandtabs(tabsize=4)

    @classmethod
    def read(cls, path):
        """
        Reads from YAML file.

        Parameters
        ----------
        path : str or `pathlib.Path`
            Path of the YAML configuration file.

        Returns
        -------
        config : `AsgardpyConfig`
            Configuration built from the file; an empty file gives the
            default configuration.
        """
        config = read_yaml(path)
        # An empty YAML document loads as None, which cannot be **-unpacked.
        return cls(**(config or {}))

    @classmethod
    def from_yaml(cls, config_str):
        """
        Create from YAML string.

        Parameters
        ----------
        config_str : str
            YAML text; an empty string gives the default configuration.

        Returns
        -------
        config : `AsgardpyConfig`
        """
        settings = yaml.safe_load(config_str)
        # yaml.safe_load returns None for an empty document.
        return cls(**(settings or {}))

    def write(self, path, overwrite=False):
        """
        Write to YAML file.

        Parameters
        ----------
        path : str or `pathlib.Path`
            Output file path.
        overwrite : bool
            Allow replacing an existing file.

        Raises
        ------
        OSError
            If the file exists and ``overwrite`` is False.
        """
        path = make_path(path)
        if path.exists() and not overwrite:
            raise OSError(f"File exists already: {path}")
        path.write_text(self.to_yaml())

    def to_yaml(self):
        """
        Convert to YAML string.

        Keys keep the field declaration order (``sort_keys=False``).
        """
        # Serialize with pydantic's JSON encoder and load the result back, so
        # that every value (e.g. paths, enums) becomes a plain JSON type that
        # yaml.dump can represent.
        data = json.loads(self.model_dump_json())
        return yaml.dump(data, sort_keys=False, indent=4, width=80, default_flow_style=None)

    def set_logging(self):
        """
        Set logging config.

        Calls ``logging.basicConfig``, i.e. adjusts global logging state.
        Note that ``self.general.log.level`` is upper-cased in place first.
        """
        self.general.log.level = self.general.log.level.upper()
        logging.basicConfig(**self.general.log.model_dump())
        log.info("Setting logging config: %s", self.general.log.model_dump())

    def update(self, config=None, merge_recursive=False):
        """
        Update config with provided settings.

        Parameters
        ----------
        config : str, dict or `AsgardpyConfig` object
            The other configuration settings: a YAML string, a dict with the
            same structure as the config, or a config object.
        merge_recursive : bool
            Perform a recursive merge from the other config onto the parent
            config (`recursive_merge_dicts`); otherwise `deep_update` is used.
            Forced to True when the first target component of the other
            config has an empty name.

        Returns
        -------
        config : `AsgardpyConfig` object
            Updated config object; ``self`` is not modified.

        Raises
        ------
        TypeError
            If ``config`` is of an unsupported type.

        Notes
        -----
        Both configs are dumped with ``exclude_defaults=True``. Values in
        ``config`` that equal their defaults are therefore dropped and cannot
        reset a non-default value of ``self``.
        """
        if isinstance(config, str):
            other = type(self).from_yaml(config)
        elif isinstance(config, AsgardpyConfig):
            other = config
        elif isinstance(config, Mapping):
            other = type(self)(**config)
        else:
            raise TypeError(f"Invalid type: {config}")

        # Special case of when only updating target model parameters from a
        # separate file, where the name of the source is not provided.
        # Guard against an empty component list, which used to raise IndexError.
        if other.target.components and other.target.components[0].name == "":
            merge_recursive = True

        merge = recursive_merge_dicts if merge_recursive else deep_update
        config_new = merge(
            self.model_dump(exclude_defaults=True), other.model_dump(exclude_defaults=True)
        )
        return type(self)(**config_new)
309