Passed
Pull Request — main (#179)
by Chaitanya
02:04
created

asgardpy.config.generator   B

Complexity

Total Complexity 44

Size/Duplication

Total Lines 407
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
eloc 199
dl 0
loc 407
rs 8.8798
c 0
b 0
f 0
wmc 44

8 Functions

Rating   Name   Duplication   Size   Complexity  
A get_model_template() 0 12 3
A all_model_templates() 0 12 2
C recursive_merge_dicts() 0 52 9
A deep_update() 0 12 3
A check_config() 0 15 4
B write_asgardpy_model_to_file() 0 45 5
A gammapy_model_to_asgardpy_model_config() 0 57 4
A check_gammapy_model() 0 13 3

7 Methods

Rating   Name   Duplication   Size   Complexity  
A AsgardpyConfig.read() 0 7 1
A AsgardpyConfig.__str__() 0 9 1
A AsgardpyConfig.to_yaml() 0 9 1
A AsgardpyConfig.update() 0 31 3
A AsgardpyConfig.write() 0 8 3
A AsgardpyConfig.set_logging() 0 8 1
A AsgardpyConfig.from_yaml() 0 7 1

How to fix   Complexity   

Complexity

Complex classes like asgardpy.config.generator often do a lot of different things. To break such a class down, we need to identify a cohesive component within that class. A common approach to find such a component is to look for fields/methods that share the same prefixes, or suffixes.

Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a sub-class, Extract Subclass is also a candidate, and is often faster.

1
"""
2
Main AsgardpyConfig Generator Module
3
"""
4
5
import json
6
import logging
7
import os
8
from collections.abc import Mapping
9
from enum import Enum
10
from pathlib import Path
11
12
import numpy as np
13
import yaml
14
from gammapy.modeling.models import CompoundSpectralModel, Models, SkyModel
15
from gammapy.utils.scripts import make_path, read_yaml
16
17
from asgardpy.analysis.step_base import AnalysisStepEnum
18
from asgardpy.base import BaseConfig, PathType
19
from asgardpy.data import (
20
    Dataset1DConfig,
21
    Dataset3DConfig,
22
    FitConfig,
23
    FluxPointsConfig,
24
    Target,
25
)
26
27
# Names exported by ``from asgardpy.config.generator import *``.
__all__ = [
    "all_model_templates",
    "AsgardpyConfig",
    "GeneralConfig",
    "gammapy_model_to_asgardpy_model_config",
    "get_model_template",
    "recursive_merge_dicts",
    "write_asgardpy_model_to_file",
]

# Directory holding this module; the ``model_templates`` folder is looked up
# relative to it.
CONFIG_PATH = Path(__file__).resolve().parent

# Module-level logger.
log = logging.getLogger(__name__)
40
41
42
# Other general config params
43
class LogConfig(BaseConfig):
    """Logging options, used as the ``log`` section of `GeneralConfig`."""

    # Logging level name; `AsgardpyConfig.set_logging` upper-cases it.
    level: str = "info"
    # All fields of this section are forwarded as keyword arguments to
    # ``logging.basicConfig`` by `AsgardpyConfig.set_logging`.
    filename: str = ""
    filemode: str = "w"
    format: str = ""
    datefmt: str = ""
51
52
53
class ParallelBackendEnum(str, Enum):
    """Allowed values for the ``parallel_backend`` option of `GeneralConfig`."""

    multi = "multiprocessing"
    ray = "ray"
58
59
60
class GeneralConfig(BaseConfig):
    """General, run-wide settings for an AsgardpyAnalysis."""

    # Logging options, see `LogConfig`.
    log: LogConfig = LogConfig()
    # NOTE: the default is the string "None", not the ``None`` object.
    outdir: PathType = "None"
    n_jobs: int = 1
    parallel_backend: ParallelBackendEnum = ParallelBackendEnum.multi
    # Analysis steps; empty by default.
    steps: list[AnalysisStepEnum] = []
    overwrite: bool = True
    stacked_dataset: bool = False
70
71
72
def all_model_templates():
    """
    Collect all Template Models provided in Asgardpy, and their small tag names.

    Returns
    -------
    all_tags : `numpy.ndarray`
        One tag per template file: the part of the file name after the last
        underscore and before the first dot.
    template_files : list of `pathlib.Path`
        Sorted paths matching ``model_templates/model_template*yaml`` under
        ``CONFIG_PATH``, in the same order as ``all_tags``.
    """
    template_files = sorted(CONFIG_PATH.glob("model_templates/model_template*yaml"))

    # NOTE(review): only the text after the LAST underscore is kept, so a file
    # such as ``model_template_lp_ebl.yaml`` (the naming used by
    # `write_asgardpy_model_to_file` for compound models) gets the tag "ebl".
    all_tags = np.array([file.name.split("_")[-1].split(".")[0] for file in template_files])

    return all_tags, template_files
84
85
86
def get_model_template(spec_model_tag):
    """
    Find the template model YAML file that belongs to a given model tag.

    Parameters
    ----------
    spec_model_tag : str
        Tag to compare against the tags returned by `all_model_templates`.

    Returns
    -------
    new_model_file : `pathlib.Path` or None
        The matching template file, or None when no tag matches. If several
        tags match, the last matching file is returned.
    """
    tags, files = all_model_templates()

    matching_files = [path for path, tag in zip(files, tags, strict=True) if spec_model_tag == tag]

    return matching_files[-1] if matching_files else None
98
99
100
def check_config(config):
    """
    For a given object type, try to read it as an AsgardpyConfig object.

    Parameters
    ----------
    config : str, `pathlib.Path` or `AsgardpyConfig`
        A path to an existing YAML file, a YAML string, or an already built
        config object (returned unchanged).

    Returns
    -------
    `AsgardpyConfig`
        The config object read from the input.

    Raises
    ------
    TypeError
        If ``config`` is none of the supported types.
    """
    if isinstance(config, str | Path):
        try:
            is_file = Path(config).is_file()
        except OSError:
            # A YAML string is not a valid path; probing it may raise
            # (e.g. "File name too long"), in which case it is not a file.
            is_file = False

        if is_file:
            return AsgardpyConfig.read(config)
        return AsgardpyConfig.from_yaml(config)

    if isinstance(config, AsgardpyConfig):
        return config

    raise TypeError(f"Invalid type: {config}")
115
116
117
def check_gammapy_model(gammapy_model):
    """
    For a given object type, try to read it as a Gammapy Models object.

    Parameters
    ----------
    gammapy_model : `Models`, `SkyModel` or any input accepted by `Models.read`
        Model object, or the location of a models file.

    Returns
    -------
    models_gpy : `Models`
        The input wrapped in, or read as, a Gammapy Models object.

    Raises
    ------
    TypeError
        If `Models.read` fails with a KeyError for the given input.
    """
    if isinstance(gammapy_model, Models | SkyModel):
        return Models(gammapy_model)

    try:
        return Models.read(gammapy_model)
    except KeyError as err:
        # Format the message here (exceptions do not apply %-style arguments)
        # and chain the caught instance so the original traceback is kept.
        raise TypeError(f"{gammapy_model} File cannot be read by Gammapy Models") from err
130
131
132
def recursive_merge_dicts(base_config, extra_config):
133
    """
134
    recursively merge two dictionaries.
135
    Entries in extra_config override entries in base_config. The built-in
136
    update function cannot be used for hierarchical dicts.
137
138
    Also for the case when there is a list of dicts involved, one has to be
139
    more careful. The extra_config may have longer list of dicts as compared
140
    with the base_config, in which case, the extra items are simply added to
141
    the merged final list.
142
143
    Combined here are 2 options from SO.
144
145
    See:
146
    http://stackoverflow.com/questions/3232943/update-value-of-a-nested-dictionary-of-varying-depth/3233356#3233356
147
    and also
148
    https://stackoverflow.com/questions/3232943/update-value-of-a-nested-dictionary-of-varying-depth/18394648#18394648
149
150
    Parameters
151
    ----------
152
    base_config : dict
153
        dictionary to be merged
154
    extra_config : dict
155
        dictionary to be merged
156
    Returns
157
    -------
158
    final_config : dict
159
        merged dict
160
    """
161
    final_config = base_config.copy()
162
163
    for key, value in extra_config.items():
164
        if key in final_config and isinstance(final_config[key], list):
165
            new_config = []
166
167
            for key_, value_ in zip(final_config[key], value, strict=False):
168
                key_ = recursive_merge_dicts(key_ or {}, value_)
169
                new_config.append(key_)
170
171
            # For example moving from a smaller list of model parameters to a
172
            # longer list.
173
            if len(final_config[key]) < len(extra_config[key]):
174
                for value_ in value[len(final_config[key]) :]:
175
                    new_config.append(value_)
176
            final_config[key] = new_config
177
178
        elif key in final_config and isinstance(final_config[key], dict):
179
            final_config[key] = recursive_merge_dicts(final_config.get(key) or {}, value)
180
        else:
181
            final_config[key] = value
182
183
    return final_config
184
185
186
def deep_update(d, u):
    """
    Recursively update a nested dictionary.

    Just like in Gammapy, taken from: https://stackoverflow.com/a/3233356/19802442

    Parameters
    ----------
    d : dict
        Dictionary to update; it is modified in place.
    u : mapping
        Values to apply. Mapping values are merged into the matching entry of
        ``d``; any other value replaces the entry.

    Returns
    -------
    d : dict
        The same object that was passed in, after the update.
    """
    for k, v in u.items():
        if isinstance(v, Mapping):
            current = d.get(k)
            # An existing non-mapping entry (e.g. None) cannot be merged into,
            # so start from an empty dict instead of failing on it.
            if not isinstance(current, Mapping):
                current = {}
            d[k] = deep_update(current, v)
        else:
            d[k] = v
    return d
198
199
200
def gammapy_model_to_asgardpy_model_config(gammapy_model, asgardpy_config_file=None, recursive_merge=True):
    """
    Convert a Gammapy Models object into the target section of an
    AsgardpyConfig object.

    The gammapy_model object may be a YAML config filename/path/object or a
    Gammapy Models object itself. Only the first model component is adapted.

    Parameters
    ----------
    gammapy_model : object accepted by `check_gammapy_model`
        The model(s) to convert.
    asgardpy_config_file : object accepted by `check_config`, optional
        Config to update. When not given, a default AsgardpyConfig is used and
        the ``datasets_names`` and ``name`` entries of the first component are
        removed.
    recursive_merge : bool
        Merge with `recursive_merge_dicts` if True, else with `deep_update`.

    Return
    ------
    asgardpy_config: `asgardpy.config.generator.AsgardpyConfig`
        Updated AsgardpyConfig object
    """
    gpy_dict = check_gammapy_model(gammapy_model).to_dict()

    use_default_config = not asgardpy_config_file
    asgardpy_config = AsgardpyConfig() if use_default_config else check_config(asgardpy_config_file)

    first_component = gpy_dict["components"][0]

    if use_default_config:
        # Remove any name values in the model dict
        for name_key in ("datasets_names", "name"):
            first_component.pop(name_key, None)

    spectral = first_component["spectral"]

    # For EBL part only: flatten the compound model (model1 combined with
    # model2) into the layout with a separate ``ebl_abs`` entry.
    if "model1" in spectral:
        ebl_abs = spectral["model2"]
        ebl_parameters = ebl_abs["parameters"]
        ebl_abs["alpha_norm"] = ebl_parameters[0]["value"]
        ebl_abs["redshift"] = ebl_parameters[1]["value"]
        ebl_abs.pop("parameters", None)

        intrinsic = spectral["model1"]
        spectral["type"] = intrinsic["type"]
        spectral["parameters"] = intrinsic["parameters"]
        spectral["ebl_abs"] = ebl_abs

        for compound_key in ("model1", "model2", "operator"):
            spectral.pop(compound_key, None)

    target_dict = asgardpy_config.model_dump()["target"]

    # deep_update is meant for when there are nans present in the other config
    # file, which are the defaults in Gammapy, but NOT in Asgardpy.
    # E.g. test data Fermi-3fhl-crab model file
    merge = recursive_merge_dicts if recursive_merge else deep_update
    asgardpy_config.target = merge(target_dict, gpy_dict)

    return asgardpy_config
257
258
259
def write_asgardpy_model_to_file(gammapy_model, output_file=None, recursive_merge=True):
    """
    Read the Gammapy Models object and save it as AsgardpyConfig YAML file
    containing only the Model parameters, similar to the model templates
    available.

    Parameters
    ----------
    gammapy_model : object accepted by `check_gammapy_model`
        The model(s) to write; only the first model is used.
    output_file : str or `pathlib.Path`, optional
        Destination file. Environment variables in a string value are
        expanded. When not given, the file is written to
        ``model_templates/model_template_<tag>.yaml`` under ``CONFIG_PATH``,
        where ``<tag>`` is the second tag of the spectral model, with "_ebl"
        appended for a compound spectral model.
    recursive_merge : bool
        Passed on to `gammapy_model_to_asgardpy_model_config`.
    """
    gammapy_model = check_gammapy_model(gammapy_model)
    source_model = gammapy_model[0]

    asgardpy_config = gammapy_model_to_asgardpy_model_config(
        gammapy_model=source_model,
        asgardpy_config_file=None,
        recursive_merge=recursive_merge,
    )

    is_compound = isinstance(source_model.spectral_model, CompoundSpectralModel)

    if not output_file:
        if is_compound:
            model_tag = source_model.spectral_model.model1.tag[1] + "_ebl"
        else:
            model_tag = source_model.spectral_model.tag[1]

        # Built from CONFIG_PATH, so no environment variable expansion is
        # applied here (the former expandvars call discarded its result).
        output_file = CONFIG_PATH / f"model_templates/model_template_{model_tag}.yaml"
    elif not isinstance(output_file, Path):
        output_file = Path(os.path.expandvars(output_file))

    temp_ = asgardpy_config.model_dump(exclude_defaults=True)
    temp_["target"].pop("models_file", None)

    spectral = temp_["target"]["components"][0]["spectral"]

    if is_compound:
        # Store the filename as a plain string for the YAML dump.
        spectral["ebl_abs"]["filename"] = str(spectral["ebl_abs"]["filename"])
    else:
        spectral.pop("ebl_abs", None)

    yaml_ = yaml.dump(
        temp_,
        sort_keys=False,
        indent=4,
        width=80,
        default_flow_style=None,
    )

    output_file.write_text(yaml_)
304
305
306
# Combine everything!
307
class AsgardpyConfig(BaseConfig):
    """
    Asgardpy analysis configuration, based on Gammapy Analysis Config.

    Combines the general, target, dataset, fit and flux-points sections into
    a single config object that can be read from and written to YAML.
    """

    general: GeneralConfig = GeneralConfig()

    target: Target = Target()

    dataset3d: Dataset3DConfig = Dataset3DConfig()
    dataset1d: Dataset1DConfig = Dataset1DConfig()

    fit_params: FitConfig = FitConfig()
    flux_points_params: FluxPointsConfig = FluxPointsConfig()

    def __str__(self):
        """
        Display settings in pretty YAML format.

        The class name is followed by the YAML dump, indented by one tab that
        is expanded to 4 spaces.
        """
        info = self.__class__.__name__ + "\n\n\t"
        data = self.to_yaml()
        data = data.replace("\n", "\n\t")
        info += data
        return info.expandtabs(tabsize=4)

    @classmethod
    def read(cls, path):
        """
        Reads from YAML file.

        Returns an instance of the class the method is called on.
        """
        config = read_yaml(path)
        return cls(**config)

    @classmethod
    def from_yaml(cls, config_str):
        """
        Create from YAML string.

        Returns an instance of the class the method is called on.
        """
        settings = yaml.safe_load(config_str)
        return cls(**settings)

    def write(self, path, overwrite=False):
        """
        Write to YAML file.

        Raises
        ------
        OSError
            If the file already exists and ``overwrite`` is False.
        """
        path = make_path(path)
        if path.exists() and not overwrite:
            raise OSError(f"File exists already: {path}")
        path.write_text(self.to_yaml())

    def to_yaml(self):
        """
        Convert to YAML string.
        """
        # Here using `dict()` instead of `json()` would be more natural.
        # We should change this once pydantic adds support for custom encoders
        # to `dict()`. See https://github.com/samuelcolvin/pydantic/issues/1043
        data = json.loads(self.model_dump_json())
        return yaml.dump(data, sort_keys=False, indent=4, width=80, default_flow_style=None)

    def set_logging(self):
        """
        Set logging config.
        Calls ``logging.basicConfig``, i.e. adjusts global logging state.

        The stored log level is upper-cased in place before it is used.
        """
        self.general.log.level = self.general.log.level.upper()
        logging.basicConfig(**self.general.log.model_dump())
        log.info("Setting logging config: %s", self.general.log.model_dump())

    def update(self, config=None, merge_recursive=False):
        """
        Update config with provided settings.

        Only the non-default values of both configs are combined; this object
        is left unchanged and a new config object is returned.

        Parameters
        ----------
        config : str, `pathlib.Path` or `AsgardpyConfig` object
            The other configuration settings, in any form accepted by
            `check_config` (YAML file path, YAML string or config object).
        merge_recursive : bool
            Perform a recursive merge from the other config onto the parent config.

        Returns
        -------
        config : `AsgardpyConfig` object
            Updated config object.
        """
        other = check_config(config)

        # Special case of when only updating target model parameters from a
        # separate file, where the name of the source is not provided.
        if other.target.components[0].name == "":
            merge_recursive = True

        merge = recursive_merge_dicts if merge_recursive else deep_update
        config_new = merge(
            self.model_dump(exclude_defaults=True),
            other.model_dump(exclude_defaults=True),
        )
        return self.__class__(**config_new)
407