Passed
Pull Request — main (#179)
by Chaitanya
01:55
created

asgardpy.config.generator   B

Complexity

Total Complexity 43

Size/Duplication

Total Lines 365
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
eloc 181
dl 0
loc 365
rs 8.96
c 0
b 0
f 0
wmc 43

6 Functions

Rating   Name   Duplication   Size   Complexity  
A get_model_template() 0 12 3
A all_model_templates() 0 12 2
C recursive_merge_dicts() 0 52 9
A deep_update() 0 12 3
B write_asgardpy_model_to_file() 0 45 6
B gammapy_model_to_asgardpy_model_config() 0 42 7

7 Methods

Rating   Name   Duplication   Size   Complexity  
A AsgardpyConfig.__str__() 0 9 1
A AsgardpyConfig.read() 0 7 1
A AsgardpyConfig.update() 0 36 5
A AsgardpyConfig.to_yaml() 0 9 1
A AsgardpyConfig.write() 0 8 3
A AsgardpyConfig.set_logging() 0 8 1
A AsgardpyConfig.from_yaml() 0 7 1

How to fix   Complexity   

Complexity

Complex modules like asgardpy.config.generator often do many different things. To break such a module down, identify a cohesive component within it. A common approach is to look for fields/methods that share the same prefixes or suffixes.

Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a sub-class, Extract Subclass is also a candidate, and is often faster.

1
"""
2
Main AsgardpyConfig Generator Module
3
"""
4
5
import json
6
import logging
7
import os
8
from collections.abc import Mapping
9
from enum import Enum
10
from pathlib import Path
11
12
import numpy as np
13
import yaml
14
from gammapy.modeling.models import CompoundSpectralModel, Models
15
from gammapy.utils.scripts import make_path, read_yaml
16
17
from asgardpy.analysis.step_base import AnalysisStepEnum
18
from asgardpy.base import BaseConfig, PathType
19
from asgardpy.data import (
20
    Dataset1DConfig,
21
    Dataset3DConfig,
22
    FitConfig,
23
    FluxPointsConfig,
24
    Target,
25
)
26
27
# Public API of this module (names exported by ``from ... import *``).
__all__ = [
    "all_model_templates",
    "AsgardpyConfig",
    "GeneralConfig",
    "gammapy_model_to_asgardpy_model_config",
    "get_model_template",
    "recursive_merge_dicts",
    "write_asgardpy_model_to_file",
]

# Directory containing this module; the template model YAML files are looked
# up under its ``model_templates`` subdirectory.
CONFIG_PATH = Path(__file__).resolve().parent

log = logging.getLogger(__name__)
40
41
42
# Other general config params
43
class LogConfig(BaseConfig):
    """
    Config section for main logging information.

    All fields are passed as keyword arguments to ``logging.basicConfig`` by
    ``AsgardpyConfig.set_logging``, which upper-cases ``level`` beforehand.
    """

    # Logging level name, e.g. "info" (upper-cased before use).
    level: str = "info"
    # Log file path; an empty string means no log file, so
    # ``logging.basicConfig`` sets up a stream handler instead.
    filename: str = ""
    # Mode used to open ``filename``; "w" overwrites an existing file.
    filemode: str = "w"
    # Log record format string; presumably "" falls back to logging's
    # default format — verify.
    format: str = ""
    # Date/time format string; presumably "" falls back to logging's
    # default — verify.
    datefmt: str = ""
51
52
53
class ParallelBackendEnum(str, Enum):
    """
    Config enum for the choice of parallel processing backend method.

    Being a ``str`` Enum, members compare equal to their string values.
    """

    # Presumably Python's ``multiprocessing`` module — confirm with the
    # consumer of ``GeneralConfig.parallel_backend``.
    multi = "multiprocessing"
    # Presumably the Ray framework — confirm with the consumer.
    ray = "ray"
58
59
60
class GeneralConfig(BaseConfig):
    """Config section for general information for running AsgardpyAnalysis."""

    # Logging setup, applied by ``AsgardpyConfig.set_logging``.
    log: LogConfig = LogConfig()
    # Output directory. NOTE(review): the default is the *string* "None",
    # not the ``None`` object — presumably handled by PathType; confirm.
    outdir: PathType = "None"
    # Number of parallel jobs (presumably used together with
    # ``parallel_backend`` — verify).
    n_jobs: int = 1
    # Backend used for parallel processing.
    parallel_backend: ParallelBackendEnum = ParallelBackendEnum.multi
    # Analysis steps to run, as AnalysisStepEnum members.
    steps: list[AnalysisStepEnum] = []
    # Presumably whether existing output may be overwritten — verify.
    overwrite: bool = True
    # Presumably whether the datasets are stacked — verify.
    stacked_dataset: bool = False
70
71
72
def all_model_templates():
    """
    List every Template Model shipped with Asgardpy together with its short tag.

    A template file's tag is the part of its filename after the last
    underscore, cut at the first dot that follows, e.g.
    ``model_template_bpl.yaml`` gives ``bpl``.

    Returns
    -------
    all_tags : `numpy.ndarray`
        Tag names, in the same order as ``template_files``.
    template_files : list of `pathlib.Path`
        Sorted paths of the template YAML files.
    """
    template_files = sorted(CONFIG_PATH.glob("model_templates/model_template*yaml"))

    tags = [path.name.split("_")[-1].split(".")[0] for path in template_files]

    return np.array(tags), template_files
84
85
86
def get_model_template(spec_model_tag):
    """
    Find the template model YAML file for a given tag, so that it can be read
    into an AsgardpyConfig object.

    Returns the matching file path, or ``None`` when no template carries the
    requested tag. If several templates share the tag, the last one in sorted
    order is returned.
    """
    tags, files = all_model_templates()

    # Scan from the end so that the first hit is the last match in sorted order.
    for tag, path in zip(reversed(tags), reversed(files), strict=True):
        if spec_model_tag == tag:
            return path
    return None
98
99
100
def recursive_merge_dicts(base_config, extra_config):
101
    """
102
    recursively merge two dictionaries.
103
    Entries in extra_config override entries in base_config. The built-in
104
    update function cannot be used for hierarchical dicts.
105
106
    Also for the case when there is a list of dicts involved, one has to be
107
    more careful. The extra_config may have longer list of dicts as compared
108
    with the base_config, in which case, the extra items are simply added to
109
    the merged final list.
110
111
    Combined here are 2 options from SO.
112
113
    See:
114
    http://stackoverflow.com/questions/3232943/update-value-of-a-nested-dictionary-of-varying-depth/3233356#3233356
115
    and also
116
    https://stackoverflow.com/questions/3232943/update-value-of-a-nested-dictionary-of-varying-depth/18394648#18394648
117
118
    Parameters
119
    ----------
120
    base_config : dict
121
        dictionary to be merged
122
    extra_config : dict
123
        dictionary to be merged
124
    Returns
125
    -------
126
    final_config : dict
127
        merged dict
128
    """
129
    final_config = base_config.copy()
130
131
    for key, value in extra_config.items():
132
        if key in final_config and isinstance(final_config[key], list):
133
            new_config = []
134
135
            for key_, value_ in zip(final_config[key], value, strict=False):
136
                key_ = recursive_merge_dicts(key_ or {}, value_)
137
                new_config.append(key_)
138
139
            # For example moving from a smaller list of model parameters to a
140
            # longer list.
141
            if len(final_config[key]) < len(extra_config[key]):
142
                for value_ in value[len(final_config[key]) :]:
143
                    new_config.append(value_)
144
            final_config[key] = new_config
145
146
        elif key in final_config and isinstance(final_config[key], dict):
147
            final_config[key] = recursive_merge_dicts(final_config.get(key) or {}, value)
148
        else:
149
            final_config[key] = value
150
151
    return final_config
152
153
154
def deep_update(d, u):
    """
    Recursively update a nested dictionary in place.

    Values from ``u`` override those in ``d``. When a value in ``u`` is a
    mapping, it is merged into the corresponding entry of ``d`` instead of
    replacing it. If that entry is missing or is not a mapping (e.g. ``None``),
    it is replaced by a new dict built from ``u``'s mapping. Any other value,
    including lists, replaces the entry in ``d`` outright.

    Just like in Gammapy, taken from: https://stackoverflow.com/a/3233356/19802442

    Parameters
    ----------
    d : dict
        Dictionary to update; modified in place.
    u : Mapping
        Dictionary with the new values.

    Returns
    -------
    d : dict
        The same object as ``d``, after the update.
    """
    for key, value in u.items():
        if isinstance(value, Mapping):
            current = d.get(key)
            # A non-mapping entry cannot be updated key by key, so start from
            # an empty dict instead of failing on item assignment.
            if not isinstance(current, Mapping):
                current = {}
            d[key] = deep_update(current, value)
        else:
            d[key] = value
    return d
166
167
168
def gammapy_model_to_asgardpy_model_config(gammapy_model, asgardpy_config_file=None, recursive_merge=True):
    """
    Read the Gammapy Models object and save it as AsgardpyConfig object.

    The gammapy_model object may be a YAML config filename/path/object or a
    Gammapy Models object itself.

    Parameters
    ----------
    gammapy_model : `gammapy.modeling.models.Models` or str or `pathlib.Path`
        Models object, or anything accepted by ``Models.read``.
    asgardpy_config_file : str, `pathlib.Path`, `AsgardpyConfig` or None
        Base config to update. ``None`` (or any falsy value) starts from a
        default AsgardpyConfig. A str or path-like object is read with
        ``AsgardpyConfig.read``. An AsgardpyConfig object is updated in place
        and returned.
    recursive_merge : bool
        If True, merge the models into the ``target`` section with
        `recursive_merge_dicts`; otherwise use `deep_update`, which replaces
        lists wholesale.

    Returns
    -------
    asgardpy_config: `asgardpy.config.generator.AsgardpyConfig` or None
        Updated AsgardpyConfig object, or None if ``gammapy_model`` could not
        be read by Gammapy.

    Raises
    ------
    TypeError
        If ``asgardpy_config_file`` is of an unsupported type.
    """
    if isinstance(gammapy_model, Models):
        models_gpy = gammapy_model
    else:
        try:
            models_gpy = Models.read(gammapy_model)
        except KeyError:
            log.error("%s File cannot be read by Gammapy Models", gammapy_model)
            return None

    if not asgardpy_config_file:
        asgardpy_config = AsgardpyConfig()  # Default object
    elif isinstance(asgardpy_config_file, AsgardpyConfig):
        asgardpy_config = asgardpy_config_file
    elif isinstance(asgardpy_config_file, (str, os.PathLike)):  # File path
        asgardpy_config = AsgardpyConfig.read(asgardpy_config_file)
    else:
        # Previously fell through and failed with UnboundLocalError.
        raise TypeError(f"Invalid type for asgardpy_config_file: {type(asgardpy_config_file)}")

    models_gpy_dict = models_gpy.to_dict()
    asgardpy_config_target_dict = asgardpy_config.model_dump()["target"]

    if recursive_merge:
        temp_target_dict = recursive_merge_dicts(asgardpy_config_target_dict, models_gpy_dict)
    else:
        # Use when there are nans present in the other config file, which are
        # the defaults in Gammapy, but NOT in Asgardpy.
        # E.g. test data Fermi-3fhl-crab model file
        temp_target_dict = deep_update(asgardpy_config_target_dict, models_gpy_dict)
    asgardpy_config.target = temp_target_dict

    return asgardpy_config
210
211
212
def write_asgardpy_model_to_file(gammapy_model, output_file=None, recursive_merge=True):
    """
    Read the Gammapy Models object and save it as AsgardpyConfig YAML file
    containing only the Model parameters, similar to the model templates
    available.

    Parameters
    ----------
    gammapy_model : `gammapy.modeling.models.Models` or object accepted by ``Models``
        Model(s) to write.
    output_file : str, `pathlib.Path` or None
        Destination file; environment variables in it are expanded. If not
        given, the file is written to the package ``model_templates``
        directory as ``model_template_<tag>.yaml``. Here ``<tag>`` is
        ``tag[1]`` of the first model's spectral model, or of its ``model1``
        for a CompoundSpectralModel.
    recursive_merge : bool
        Passed on to `gammapy_model_to_asgardpy_model_config`.

    Returns
    -------
    None
        If ``gammapy_model`` cannot be converted to Models, an error is
        logged and nothing is written.
    """
    if not isinstance(gammapy_model, Models):
        try:
            gammapy_model = Models(gammapy_model)
        except KeyError:
            log.error("%s Object cannot be read as Gammapy Models", gammapy_model)
            return None

    asgardpy_config = gammapy_model_to_asgardpy_model_config(
        gammapy_model=gammapy_model,
        asgardpy_config_file=None,
        recursive_merge=recursive_merge,
    )

    if output_file:
        # Expand environment variables for str and Path inputs alike.
        output_file = Path(os.path.expandvars(output_file))
    else:
        spectral_model = gammapy_model[0].spectral_model
        if isinstance(spectral_model, CompoundSpectralModel):
            spectral_model = spectral_model.model1
        model_tag = spectral_model.tag[1]

        output_file = CONFIG_PATH / f"model_templates/model_template_{model_tag}.yaml"

    temp_ = asgardpy_config.model_dump(exclude_defaults=True)
    # Keep only the model definition: drop the models file reference, plus
    # the EBL absorption settings and the name of the first component.
    temp_["target"].pop("models_file", None)
    temp_["target"]["components"][0]["spectral"].pop("ebl_abs", None)
    temp_["target"]["components"][0].pop("name", None)

    yaml_ = yaml.dump(
        temp_,
        sort_keys=False,
        indent=4,
        width=80,
        default_flow_style=None,
    )

    output_file.write_text(yaml_)
257
258
259
# Combine everything!
260
class AsgardpyConfig(BaseConfig):
    """
    Asgardpy analysis configuration, based on Gammapy Analysis Config.
    """

    general: GeneralConfig = GeneralConfig()

    target: Target = Target()

    dataset3d: Dataset3DConfig = Dataset3DConfig()
    dataset1d: Dataset1DConfig = Dataset1DConfig()

    fit_params: FitConfig = FitConfig()
    flux_points_params: FluxPointsConfig = FluxPointsConfig()

    def __str__(self):
        """
        Display settings in pretty YAML format.
        """
        info = self.__class__.__name__ + "\n\n\t"
        data = self.to_yaml()
        data = data.replace("\n", "\n\t")
        info += data
        return info.expandtabs(tabsize=4)

    @classmethod
    def read(cls, path):
        """
        Reads from YAML file.

        Parameters
        ----------
        path : str or `pathlib.Path`
            YAML file to read.

        Returns
        -------
        config : `AsgardpyConfig`
            Config object (an instance of the class this is called on).
        """
        config = read_yaml(path)
        return cls(**config)

    @classmethod
    def from_yaml(cls, config_str):
        """
        Create from YAML string.

        Parameters
        ----------
        config_str : str
            YAML text, parsed with ``yaml.safe_load``.

        Returns
        -------
        config : `AsgardpyConfig`
            Config object (an instance of the class this is called on).
        """
        settings = yaml.safe_load(config_str)
        return cls(**settings)

    def write(self, path, overwrite=False):
        """
        Write to YAML file.

        Raises
        ------
        OSError
            If the file already exists and ``overwrite`` is False.
        """
        path = make_path(path)
        if path.exists() and not overwrite:
            raise OSError(f"File exists already: {path}")
        path.write_text(self.to_yaml())

    def to_yaml(self):
        """
        Convert to YAML string.
        """
        # Round-trip through JSON so that values which pydantic only turns
        # into plain types in JSON mode (e.g. paths, enums) become plain
        # str/number/list/dict objects that yaml.dump can write.
        data = json.loads(self.model_dump_json())
        return yaml.dump(data, sort_keys=False, indent=4, width=80, default_flow_style=None)

    def set_logging(self):
        """
        Set logging config.

        Calls ``logging.basicConfig``, i.e. adjusts global logging state.
        Note that this upper-cases ``self.general.log.level`` in place.
        """
        self.general.log.level = self.general.log.level.upper()
        logging.basicConfig(**self.general.log.model_dump())
        log.info("Setting logging config: %s", self.general.log.model_dump())

    def update(self, config=None, merge_recursive=False):
        """
        Update config with provided settings.

        Parameters
        ----------
        config : str, dict or `AsgardpyConfig` object
            The other configuration settings: a YAML string, a dict with the
            same structure as the config, or an AsgardpyConfig object.
        merge_recursive : bool
            Perform a recursive merge from the other config onto the parent
            config. Forced to True when the first target component of
            ``config`` has an empty name.

        Returns
        -------
        config : `AsgardpyConfig` object
            Updated config object; ``self`` is left unchanged.

        Raises
        ------
        TypeError
            If ``config`` is of an unsupported type.
        """
        if isinstance(config, str):
            other = self.from_yaml(config)
        elif isinstance(config, AsgardpyConfig):
            other = config
        elif isinstance(config, Mapping):
            other = type(self)(**config)
        else:
            raise TypeError(f"Invalid type: {config}")

        # Special case of when only updating target model parameters from a
        # separate file, where the name of the source is not provided.
        components = other.target.components
        if components and components[0].name == "":
            merge_recursive = True

        # Only explicitly set (non-default) values take part in the merge.
        base_dict = self.model_dump(exclude_defaults=True)
        other_dict = other.model_dump(exclude_defaults=True)

        if merge_recursive:
            config_new = recursive_merge_dicts(base_dict, other_dict)
        else:
            config_new = deep_update(base_dict, other_dict)
        return type(self)(**config_new)
365