gammapy.modeling.models.core   F
last analyzed

Complexity

Total Complexity 186

Size/Duplication

Total Lines 1083
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
eloc 551
dl 0
loc 1083
rs 2
c 0
b 0
f 0
wmc 186

2 Functions

Rating   Name   Duplication   Size   Complexity  
A _get_model_class_from_dict() 0 18 5
A _set_link() 0 11 4

64 Methods

Rating   Name   Duplication   Size   Complexity  
A ModelBase.parameters() 0 5 1
A ModelBase.__init__() 0 15 3
A ModelBase._check_covariance() 0 3 2
A ModelBase.from_parameters() 0 17 2
A ModelBase.__getattribute__() 0 7 2
A ModelBase.covariance() 0 10 2
A ModelBase.__init_subclass__() 0 4 1
A ModelBase.type() 0 3 1
A ModelBase.copy() 0 4 1
A DatasetModels.__add__() 0 9 4
A DatasetModels.freeze() 0 11 2
A DatasetModels.read() 0 6 1
A DatasetModels.set_parameters_bounds() 0 30 4
A DatasetModels.select_mask() 0 33 5
A DatasetModels.index() 0 9 4
A DatasetModels.names() 0 3 1
A DatasetModels.to_yaml() 0 5 1
A DatasetModels.frozen() 0 4 1
A DatasetModels.reassign() 0 14 1
B DatasetModels.to_dict() 0 23 6
A DatasetModels.__str__() 0 11 2
A DatasetModels.read_covariance() 0 21 1
A DatasetModels.positions() 0 14 3
D DatasetModels.selection_mask() 0 63 13
A DatasetModels.wcs_geom() 0 8 2
A DatasetModels._ipython_key_completions_() 0 2 1
A DatasetModels.plot_positions() 0 33 4
A DatasetModels.unfreeze() 0 11 2
A DatasetModels.from_yaml() 0 5 1
A DatasetModels.__getitem__() 0 5 3
A DatasetModels.copy() 0 21 2
A DatasetModels.covariance() 0 8 2
A DatasetModels.parameters_unique_names() 0 11 3
A DatasetModels.to_regions() 0 19 3
A DatasetModels.update_link_label() 0 12 5
A DatasetModels.plot_regions() 0 26 1
A DatasetModels.__len__() 0 2 1
A DatasetModels.write_covariance() 0 20 2
B DatasetModels.from_dict() 0 36 8
B DatasetModels.__init__() 0 20 7
A DatasetModels.restore_status() 0 14 1
A DatasetModels.select() 0 33 1
A DatasetModels.select_region() 0 26 3
A DatasetModels._check_covariance() 0 4 2
A DatasetModels.to_parameters_table() 0 7 1
B DatasetModels.write() 0 43 6
A DatasetModels.parameters() 0 3 1
A DatasetModels.update_parameters_from_table() 0 5 2
A DatasetModels.to_template_sky_model() 0 32 3
C ModelBase.to_dict() 0 35 11
A ModelBase.unfreeze() 0 4 2
A Models.insert() 0 5 2
A restore_models_status.__init__() 0 6 1
A restore_models_status.__exit__() 0 6 3
A ModelBase.__str__() 0 5 2
A Model.from_dict() 0 6 1
A ModelBase.frozen() 0 4 1
A Models.__delitem__() 0 2 1
A ModelBase.freeze() 0 3 1
B ModelBase.from_dict() 0 37 6
A Models.__setitem__() 0 7 2
A restore_models_status.__enter__() 0 2 1
B ModelBase.reassign() 0 36 6
A Model.create() 0 20 2

How to fix   Complexity   

Complexity

Complex classes like gammapy.modeling.models.core often do a lot of different things. To break such a class down, we need to identify a cohesive component within that class. A common approach to find such a component is to look for fields/methods that share the same prefixes, or suffixes.

Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a sub-class, Extract Subclass is also a candidate, and is often faster.

1
# Licensed under a 3-clause BSD style license - see LICENSE.rst
2
import collections.abc
3
import copy
4
import logging
5
from os.path import split
6
import numpy as np
7
import astropy.units as u
8
from astropy.coordinates import SkyCoord
9
from astropy.table import Table
10
import matplotlib.pyplot as plt
11
import yaml
12
from gammapy.maps import Map, RegionGeom
13
from gammapy.modeling import Covariance, Parameter, Parameters
14
from gammapy.modeling.covariance import copy_covariance
15
from gammapy.utils.scripts import make_name, make_path
16
17
__all__ = ["Model", "Models", "DatasetModels", "ModelBase"]
18
19
20
log = logging.getLogger(__name__)
21
22
23
def _set_link(shared_register, model):
24
    for param in model.parameters:
25
        name = param.name
26
        link_label = param._link_label_io
27
        if link_label is not None:
28
            if link_label in shared_register:
29
                new_param = shared_register[link_label]
30
                setattr(model, name, new_param)
31
            else:
32
                shared_register[link_label] = param
33
    return shared_register
34
35
36
def _get_model_class_from_dict(data):
37
    """get a model class from a dict"""
38
    from . import (
39
        MODEL_REGISTRY,
40
        SPATIAL_MODEL_REGISTRY,
41
        SPECTRAL_MODEL_REGISTRY,
42
        TEMPORAL_MODEL_REGISTRY,
43
    )
44
45
    if "type" in data:
46
        cls = MODEL_REGISTRY.get_cls(data["type"])
47
    elif "spatial" in data:
48
        cls = SPATIAL_MODEL_REGISTRY.get_cls(data["spatial"]["type"])
49
    elif "spectral" in data:
50
        cls = SPECTRAL_MODEL_REGISTRY.get_cls(data["spectral"]["type"])
51
    elif "temporal" in data:
52
        cls = TEMPORAL_MODEL_REGISTRY.get_cls(data["temporal"]["type"])
53
    return cls
0 ignored issues
show
introduced by
The variable cls does not seem to be defined for all execution paths.
Loading history...
54
55
56
class ModelBase:
    """Model base class.

    Sub-classes declare their parameters as class attributes holding
    `Parameter` objects. These are collected into ``default_parameters`` when
    the sub-class is created, and a copy of them is attached to every
    instance on init.
    """

    _type = None

    def __init__(self, **kwargs):
        """Create the model, optionally overriding default parameters.

        Parameters
        ----------
        **kwargs : dict
            Values keyed by parameter name. A `Parameter` replaces the default
            parameter; any other value is converted with
            `~astropy.units.Quantity` and assigned to the copied default.
        """
        # Copy default parameters from the class to the instance
        default_parameters = self.default_parameters.copy()

        for par in default_parameters:
            value = kwargs.get(par.name, par)

            if not isinstance(value, Parameter):
                par.quantity = u.Quantity(value)
            else:
                par = value

            setattr(self, par.name, par)

        self._covariance = Covariance(self.parameters)

    def __getattribute__(self, name):
        """Attribute access that resolves `Parameter` values via their ``__get__``."""
        value = object.__getattribute__(self, name)

        if isinstance(value, Parameter):
            return value.__get__(self, None)

        return value

    @property
    def type(self):
        """Model type stored in the ``_type`` class attribute (None by default)."""
        return self._type

    def __init_subclass__(cls, **kwargs):
        # Add parameters list on the model sub-class (not instances)
        cls.default_parameters = Parameters(
            [_ for _ in cls.__dict__.values() if isinstance(_, Parameter)]
        )

    @classmethod
    def from_parameters(cls, parameters, **kwargs):
        """Create model from parameter list

        Parameters
        ----------
        parameters : `Parameters`
            Parameters for init
        **kwargs : dict
            Additional keyword arguments passed to the class init. Entries
            sharing a name with one of the parameters are overwritten.

        Returns
        -------
        model : `Model`
            Model instance
        """
        for par in parameters:
            kwargs[par.name] = par
        return cls(**kwargs)

    def _check_covariance(self):
        """Reset the cached covariance if it no longer matches the parameters."""
        if not self.parameters == self._covariance.parameters:
            self._covariance = Covariance(self.parameters)

    @property
    def covariance(self):
        """Covariance of the model parameters.

        On every access the cached covariance is refreshed with one 1x1
        sub-covariance per parameter, holding the squared parameter error
        (NaN is replaced by 1).
        """
        self._check_covariance()
        for par in self.parameters:
            pars = Parameters([par])
            error = np.nan_to_num(par.error**2, nan=1)
            covar = Covariance(pars, data=[[error]])
            self._covariance.set_subcovariance(covar)

        return self._covariance

    @covariance.setter
    def covariance(self, covariance):
        self._check_covariance()
        self._covariance.data = covariance

        # Propagate the new covariance to the individual parameter errors
        for par in self.parameters:
            pars = Parameters([par])
            variance = self._covariance.get_subcovariance(pars)
            par.error = np.sqrt(variance)

    @property
    def parameters(self):
        """Parameters (`~gammapy.modeling.Parameters`)"""
        return Parameters(
            [getattr(self, name) for name in self.default_parameters.names]
        )

    @copy_covariance
    def copy(self, **kwargs):
        """A deep copy."""
        return copy.deepcopy(self)

    @staticmethod
    def _matches_default(value, default):
        """Check whether a serialised parameter entry equals its default.

        Two NaN values count as equal. Entries that `numpy.isnan` cannot
        handle (it raises TypeError for them, e.g. for strings) are simply
        treated as different when they do not compare equal.
        """
        if value == default:
            return True

        try:
            return bool(np.isnan(value) and np.isnan(default))
        except TypeError:
            return False

    def to_dict(self, full_output=False):
        """Create dict for YAML serialisation

        Parameters
        ----------
        full_output : bool
            If False, parameter entries equal to their default are dropped,
            as well as "frozen" when it is false and "unit" when the default
            unit is an empty string.

        Returns
        -------
        data : dict
            ``{"type": tag, "parameters": [...]}``, nested under the model
            type as key when ``self.type`` is not None.
        """
        tag = self.tag[0] if isinstance(self.tag, list) else self.tag
        params = self.parameters.to_dict()

        if not full_output:
            for par, par_default in zip(params, self.default_parameters):
                init = par_default.to_dict()
                for item in [
                    "min",
                    "max",
                    "error",
                    "interp",
                    "scale_method",
                    "is_norm",
                ]:
                    # NaN-aware comparison that does not fail on non-numeric entries
                    if self._matches_default(par[item], init[item]):
                        del par[item]

                if not par["frozen"]:
                    del par["frozen"]

                if init["unit"] == "":
                    del par["unit"]

        data = {"type": tag, "parameters": params}

        if self.type is None:
            return data
        else:
            return {self.type: data}

    @classmethod
    def from_dict(cls, data):
        """Create model from dict

        Parameters
        ----------
        data : dict
            Serialised model. If its first key is "spatial", "temporal" or
            "spectral", the nested dict stored under that key is used.

        Returns
        -------
        model : `Model`
            Model instance

        Raises
        ------
        ValueError
            If ``data["type"]`` is not one of the tags of the class.
        """
        kwargs = {}

        par_data = []
        key0 = next(iter(data))

        if key0 in ["spatial", "temporal", "spectral"]:
            data = data[key0]

        # A plain string tag must match exactly, not as a substring
        tags = [cls.tag] if isinstance(cls.tag, str) else cls.tag

        if data["type"] not in tags:
            raise ValueError(
                f"Invalid model type {data['type']} for class {cls.__name__}"
            )

        input_names = [_["name"] for _ in data["parameters"]]

        for par in cls.default_parameters:
            par_dict = par.to_dict()
            try:
                index = input_names.index(par_dict["name"])
                par_dict.update(data["parameters"][index])
            except ValueError:
                log.warning(
                    f"Parameter '{par_dict['name']}' not defined in YAML file."
                    f" Using default value: {par_dict['value']} {par_dict['unit']}"
                )
            par_data.append(par_dict)

        parameters = Parameters.from_dict(par_data)

        # TODO: this is a special case for spatial models, maybe better move to
        #  `SpatialModel` base class
        if "frame" in data:
            kwargs["frame"] = data["frame"]

        return cls.from_parameters(parameters, **kwargs)

    def __str__(self):
        string = f"{self.__class__.__name__}\n"
        if len(self.parameters) > 0:
            string += f"\n{self.parameters.to_table()}"
        return string

    @property
    def frozen(self):
        """Frozen status of a model, True if all parameters are frozen"""
        return np.all([p.frozen for p in self.parameters])

    def freeze(self):
        """Freeze all parameters"""
        self.parameters.freeze_all()

    def unfreeze(self):
        """Restore parameters frozen status to default"""
        for p, default in zip(self.parameters, self.default_parameters):
            p.frozen = default.frozen

    def reassign(self, datasets_names, new_datasets_names):
        """Reassign a model from one dataset to another

        Parameters
        ----------
        datasets_names : str or list
            Name of the datasets where the model is currently defined
        new_datasets_names : str or list
            Name of the datasets where the model should be defined instead.
            If multiple names are given the two lists must have the same length,
            as the reassignment is element-wise.

        Returns
        -------
        model : `Model`
            Reassigned model.

        """
        model = self.copy(name=self.name)

        if not isinstance(datasets_names, list):
            datasets_names = [datasets_names]

        if not isinstance(new_datasets_names, list):
            new_datasets_names = [new_datasets_names]

        # Models without a ``datasets_names`` attribute are returned unchanged
        current_names = getattr(model, "datasets_names", None)

        if isinstance(current_names, str):
            current_names = [current_names]
            model.datasets_names = current_names

        if current_names:
            # NOTE(review): `str.replace` substitutes sub-strings, so a name that
            # merely contains ``name`` is modified as well -- confirm this is intended
            for name, name_new in zip(datasets_names, new_datasets_names):
                current_names = [_.replace(name, name_new) for _ in current_names]

            model.datasets_names = current_names

        return model
280
281
282
class Model:
    """Model class that contains only methods to create a model listed in the registries."""

    @staticmethod
    def create(tag, model_type=None, *args, **kwargs):
        """Create a model instance.

        Parameters
        ----------
        tag : str
            Tag of the model class to look up.
        model_type : str, optional
            Key under which the tag is nested for the lookup
            (e.g. "spectral"). If None the tag is looked up directly.
        *args, **kwargs
            Passed to the init of the selected model class.

        Examples
        --------
        >>> from gammapy.modeling.models import Model
        >>> spectral_model = Model.create(
        ...     "pl-2", model_type="spectral", amplitude="1e-10 cm-2 s-1", index=3
        ... )
        >>> type(spectral_model)
        <class 'gammapy.modeling.models.spectral.PowerLaw2SpectralModel'>
        """
        if model_type is None:
            spec = {"type": tag}
        else:
            spec = {model_type: {"type": tag}}

        model_cls = _get_model_class_from_dict(spec)
        return model_cls(*args, **kwargs)

    @staticmethod
    def from_dict(data):
        """Create a model instance from a dict"""
        model_cls = _get_model_class_from_dict(data)
        return model_cls.from_dict(data)
312
313
314
class DatasetModels(collections.abc.Sequence):
315
    """Immutable models container
316
317
    Parameters
318
    ----------
319
    models : `SkyModel`, list of `SkyModel` or `Models`
320
        Sky models
321
    """
322
323
    def __init__(self, models=None):
        """Build the container and check that the model names are unique.

        Raises
        ------
        TypeError
            If ``models`` is neither None, a models container, a single model
            nor a list.
        ValueError
            If two models share the same name.
        """
        if models is None:
            model_list = []
        elif isinstance(models, (Models, DatasetModels)):
            # The underlying list is re-used here, it is not copied
            model_list = models._models
        elif isinstance(models, ModelBase):
            model_list = [models]
        elif isinstance(models, list):
            model_list = models
        else:
            raise TypeError(f"Invalid type: {models!r}")

        seen_names = []
        for entry in model_list:
            entry_name = entry.name
            if entry_name in seen_names:
                raise ValueError("Model names must be unique")
            seen_names.append(entry_name)

        self._models = model_list
        self._covar_file = None
        self._covariance = Covariance(self.parameters)
343
344
    def _check_covariance(self):
        """Rebuild the cached covariance when it no longer matches the parameters."""
        if self.parameters == self._covariance.parameters:
            return

        per_model = [entry.covariance for entry in self._models]
        self._covariance = Covariance.from_stack(per_model)
349
350
    @property
    def covariance(self):
        """Covariance of all models.

        On every access the covariance of each model is written into the
        cached covariance as a sub-covariance.
        """
        self._check_covariance()

        for entry in self._models:
            model_covariance = entry.covariance
            self._covariance.set_subcovariance(model_covariance)

        return self._covariance

    @covariance.setter
    def covariance(self, covariance):
        self._check_covariance()
        self._covariance.data = covariance

        # Hand the matching sub-covariance back to every model
        for entry in self._models:
            model_parameters = entry.covariance.parameters
            entry.covariance = self._covariance.get_subcovariance(model_parameters)
367
368
    @property
    def parameters(self):
        """Parameters of all models, stacked with `Parameters.from_stack`."""
        per_model = [entry.parameters for entry in self._models]
        return Parameters.from_stack(per_model)
371
372
    @property
373
    def parameters_unique_names(self):
374
        """List of unique parameter names as model_name.par_type.par_name"""
375
        names = []
376
        for model in self:
377
            for par in model.parameters:
378
                components = [model.name, par.type, par.name]
379
                name = ".".join(components)
380
                names.append(name)
381
382
        return names
383
384
    @property
385
    def names(self):
386
        return [m.name for m in self._models]
387
388
    @classmethod
    def read(cls, filename):
        """Read from YAML file.

        The directory part of ``filename`` is passed on as ``path`` to
        `from_yaml`.
        """
        text = make_path(filename).read_text()
        directory = split(filename)[0]
        return cls.from_yaml(text, path=directory)
394
395
    @classmethod
    def from_yaml(cls, yaml_str, path=""):
        """Create from YAML string.

        The string is parsed with ``yaml.safe_load`` and handed to `from_dict`
        together with ``path``.
        """
        parsed = yaml.safe_load(yaml_str)
        return cls.from_dict(parsed, path=path)
400
401
    @classmethod
    def from_dict(cls, data, path=""):
        """Create from dict.

        Parameters
        ----------
        data : dict
            Dict with a "components" list and optionally a "covariance"
            filename.
        path : str or `Path`
            Base path used to locate the covariance file. If the file does not
            exist below it, the filename is split into its own directory and
            name instead.
        """
        from . import MODEL_REGISTRY, SkyModel

        components = [
            MODEL_REGISTRY.get_cls(component["type"]).from_dict(component)
            for component in data["components"]
        ]
        models = cls(components)

        if "covariance" in data:
            filename = data["covariance"]
            path = make_path(path)
            if not (path / filename).exists():
                path, filename = split(filename)

            models.read_covariance(path, filename, format="ascii.fixed_width")

        # Re-establish parameter links, using one register across all models
        shared_register = {}
        for model in models:
            if isinstance(model, SkyModel):
                candidates = [
                    model.spectral_model,
                    model.spatial_model,
                    model.temporal_model,
                ]
                targets = [_ for _ in candidates if _ is not None]
            else:
                targets = [model]

            for target in targets:
                shared_register = _set_link(shared_register, target)

        return models
437
438
    def write(
        self,
        path,
        overwrite=False,
        full_output=False,
        overwrite_templates=False,
        write_covariance=True,
    ):
        """Write to YAML file.

        When the covariance is written, it goes to a file named
        ``<stem>_covariance.dat`` in the directory part of ``path``.

        Parameters
        ----------
        path : `pathlib.Path` or str
            path to write files
        overwrite : bool
            overwrite YAML files
        full_output : bool
            Store full parameter output.
        overwrite_templates : bool
            overwrite templates FITS files
        write_covariance : bool
            save covariance or not

        Raises
        ------
        IOError
            If the YAML file exists already and ``overwrite`` is False.
        """
        directory = make_path(split(path)[0])
        path = make_path(path)

        if path.exists() and not overwrite:
            raise IOError(f"File exists already: {path}")

        save_covariance = (
            write_covariance
            and self.covariance is not None
            and len(self.parameters) != 0
        )

        if save_covariance:
            covar_name = path.stem + "_covariance.dat"
            self.write_covariance(
                directory / covar_name,
                format="ascii.fixed_width",
                delimiter="|",
                overwrite=overwrite,
            )
            # Recorded before serialising, so that the YAML output references the file
            self._covar_file = covar_name

        yaml_text = self.to_yaml(full_output, overwrite_templates)
        path.write_text(yaml_text)
481
482
    def to_yaml(self, full_output=False, overwrite_templates=False):
        """Convert to YAML string.

        Both arguments are passed on to `to_dict`; the resulting dict is
        dumped in block style, keeping the key order.
        """
        dump_options = dict(
            sort_keys=False, indent=4, width=80, default_flow_style=False
        )
        return yaml.dump(
            self.to_dict(full_output, overwrite_templates), **dump_options
        )
488
489
    def update_link_label(self):
        """Update linked parameters labels used for serialization and print.

        A parameter met more than once while iterating ``self.parameters`` is
        considered shared and gets a new ``_link_label_io`` of the form
        ``<parameter name>@<generated name>``.
        """
        seen = []
        shared = []

        for param in self.parameters:
            if param not in seen:
                # Registered once only; the second, redundant append was dropped
                seen.append(param)
            elif param not in shared:
                shared.append(param)

        for param in shared:
            param._link_label_io = param.name + "@" + make_name()
501
502
    def to_dict(self, full_output=False, overwrite_templates=False):
        """Convert to dict.

        Besides building the dict, this refreshes the link labels and calls
        ``write`` on every spatial model whose tag contains "template".

        Parameters
        ----------
        full_output : bool
            Passed to the ``to_dict`` method of every model.
        overwrite_templates : bool
            Passed as ``overwrite`` when writing template spatial models.

        Returns
        -------
        data : dict
            Dict with a "components" entry, and a "covariance" entry when a
            covariance file is associated with the container.
        """
        self.update_link_label()

        components = []
        for entry in self._models:
            components.append(entry.to_dict(full_output))

            spatial = getattr(entry, "spatial_model", None)
            if spatial is not None and "template" in spatial.tag:
                spatial.write(overwrite=overwrite_templates)

        data = {"components": components}

        if self._covar_file is not None:
            data["covariance"] = str(self._covar_file)

        return data
525
526
    def to_parameters_table(self):
        """Convert Models parameters to an astropy Table.

        A "model" column holding the model name is inserted as first column.
        """
        table = self.parameters.to_table()

        # Warning: this breaks if a model name contains a ".", as the model name
        # is taken to be everything before the first "." of the unique name.
        owners = [
            unique_name.partition(".")[0]
            for unique_name in self.parameters_unique_names
        ]
        table.add_column(owners, name="model", index=0)
        return table
533
534
    def update_parameters_from_table(self, t):
        """Update Models from an astropy Table.

        Row ``k`` of the table updates parameter ``k`` of ``self.parameters``,
        using a dict that maps the column names to the row values.
        """
        column_names = t.colnames
        rows = [dict(zip(column_names, row)) for row in t]

        for position, row_data in enumerate(rows):
            self.parameters[position].update_from_dict(row_data)
539
540
    def read_covariance(self, path, filename="_covariance.dat", **kwargs):
        """Read covariance data from file

        The "Parameters" column is dropped, the remaining columns are viewed
        as a float array and assigned to ``covariance``. ``filename`` is
        remembered as the covariance file of the container.

        Parameters
        ----------
        path : str or `Path`
            Base path
        filename : str
            Filename
        **kwargs : dict
            Keyword arguments passed to `~astropy.table.Table.read`

        """
        filepath = str(make_path(path) / filename)

        table = Table.read(filepath, **kwargs)
        table.remove_column("Parameters")

        # assumes all remaining columns are float columns -- TODO confirm
        records = np.array(table)
        self.covariance = records.view(float).reshape(records.shape + (-1,))
        self._covar_file = filename
561
562
    def write_covariance(self, filename, **kwargs):
        """Write covariance to file

        The table has a "Parameters" column with the unique parameter names,
        followed by one column per parameter, named after its index.

        Parameters
        ----------
        filename : str
            Filename
        **kwargs : dict
            Keyword arguments passed to `~astropy.table.Table.write`

        """
        names = self.parameters_unique_names
        table = Table()
        table["Parameters"] = names

        # Evaluate the property once: each access refreshes the sub-covariances
        # of all models, which is wasted work inside the loop.
        covariance_data = self.covariance.data

        for idx in range(len(names)):
            table[str(idx)] = covariance_data[idx]

        table.write(make_path(filename), **kwargs)
582
583
    def __str__(self):
        """Class name followed by one "Component <idx>: ..." entry per model.

        The link labels are refreshed first; tabs are expanded to two spaces.
        """
        self.update_link_label()

        pieces = [f"{self.__class__.__name__}\n\n"]

        for idx, model in enumerate(self):
            pieces.append(f"Component {idx}: ")
            pieces.append(str(model))

        return "".join(pieces).expandtabs(tabsize=2)
594
595
    def __add__(self, other):
        """Return a new `Models` with ``other`` appended.

        ``other`` can be a `Models`, a list of models or a single model.

        Raises
        ------
        ValueError
            If a single model is added whose name is already taken.
        TypeError
            For any other type of ``other``.
        """
        if isinstance(other, (Models, list)):
            return Models([*self, *other])

        if not isinstance(other, ModelBase):
            raise TypeError(f"Invalid type: {other!r}")

        if other.name in self.names:
            raise ValueError("Model names must be unique")

        return Models([*self, other])
604
605
    def __getitem__(self, key):
606
        if isinstance(key, np.ndarray) and key.dtype == bool:
607
            return self.__class__(list(np.array(self._models)[key]))
608
        else:
609
            return self._models[self.index(key)]
610
611
    def index(self, key):
612
        if isinstance(key, (int, slice)):
613
            return key
614
        elif isinstance(key, str):
615
            return self.names.index(key)
616
        elif isinstance(key, ModelBase):
617
            return self._models.index(key)
618
        else:
619
            raise TypeError(f"Invalid type: {type(key)!r}")
620
621
    def __len__(self):
622
        return len(self._models)
623
624
    def _ipython_key_completions_(self):
625
        return self.names
626
627
    @copy_covariance
    def copy(self, copy_data=False):
        """A deep copy.

        Every model is copied with its own ``copy`` method, keeping its name.

        Parameters
        ----------
        copy_data : bool
            Whether to copy data attached to template models

        Returns
        -------
        models: `Models`
            Copied models.
        """
        copies = [
            entry.copy(name=entry.name, copy_data=copy_data) for entry in self
        ]
        return self.__class__(models=copies)
648
649
    def select(
        self,
        name_substring=None,
        datasets_names=None,
        tag=None,
        model_type=None,
        frozen=None,
    ):
        """Select models that meet all specified conditions

        The conditions are evaluated by `selection_mask`, and the resulting
        boolean mask is used to index the container.

        Parameters
        ----------
        name_substring : str
            Substring contained in the model name
        datasets_names : str or list
            Name of the dataset
        tag : str or list
            Model tag
        model_type : {None, spatial, spectral}
           Type of model, used together with "tag", if the tag is not unique.
        frozen : bool
            Select models with all parameters frozen if True, exclude them if False.

        Returns
        -------
        models : `DatasetModels`
            Selected models
        """
        criteria = (name_substring, datasets_names, tag, model_type, frozen)
        return self[self.selection_mask(*criteria)]
682
683
    def selection_mask(
684
        self,
685
        name_substring=None,
686
        datasets_names=None,
687
        tag=None,
688
        model_type=None,
689
        frozen=None,
690
    ):
691
        """Create a mask of models, that meet all specified conditions
692
693
        Parameters
694
        ----------
695
        name_substring : str
696
            Substring contained in the model name
697
        datasets_names : str or list of str
698
            Name of the dataset
699
        tag : str or list of str
700
            Model tag
701
        model_type : {None, spatial, spectral}
702
           Type of model, used together with "tag", if the tag is not unique.
703
        frozen : bool
704
            Select models with all parameters frozen if True, exclude them if False.
705
706
        Returns
707
        -------
708
        mask : `numpy.array`
709
            Boolean mask, True for selected models
710
        """
711
        selection = np.ones(len(self), dtype=bool)
712
713
        if tag and not isinstance(tag, list):
714
            tag = [tag]
715
716
        if datasets_names and not isinstance(datasets_names, list):
717
            datasets_names = [datasets_names]
718
719
        for idx, model in enumerate(self):
720
            if name_substring:
721
                selection[idx] &= name_substring in model.name
722
723
            if datasets_names:
724
                selection[idx] &= model.datasets_names is None or np.any(
725
                    [name in model.datasets_names for name in datasets_names]
726
                )
727
728
            if tag:
729
                if model_type is None:
730
                    sub_model = model
731
                else:
732
                    sub_model = getattr(model, f"{model_type}_model", None)
733
734
                if sub_model:
735
                    selection[idx] &= np.any([t in sub_model.tag for t in tag])
736
                else:
737
                    selection[idx] &= False
738
739
            if frozen is not None:
740
                if frozen:
741
                    selection[idx] &= model.frozen
742
                else:
743
                    selection[idx] &= ~model.frozen
744
745
        return np.array(selection, dtype=bool)
746
747
    def select_mask(self, mask, margin="0 deg", use_evaluation_region=True):
        """Check if sky models contribute within a mask map.

        Only models selected by the tag "sky-model" are considered. A mask
        that is not an image is first reduced over its axes with a logical or.

        Parameters
        ----------
        mask : `~gammapy.maps.WcsNDMap` of boolean type
            Map containing a boolean mask
        margin : `~astropy.unit.Quantity`
            Add a margin in degree to the source evaluation radius.
            Used to take into account PSF width.
        use_evaluation_region : bool
            Account for the extension of the model or not. The default is True.
            If False, the mask value at the model position is used instead.

        Returns
        -------
        models : `DatasetModels`
            Selected models contributing inside the region where mask==True
        """
        if not mask.geom.is_image:
            mask = mask.reduce_over_axes(func=np.logical_or)

        def _contribution(model):
            if use_evaluation_region:
                return model.contributes(mask=mask, margin=margin)
            return mask.get_by_coord(model.position, fill_value=0)

        selected = [
            model
            for model in self.select(tag="sky-model")
            if np.any(_contribution(model))
        ]

        return self.__class__(models=selected)
780
781
    def select_region(self, regions, wcs=None):
        """Select sky models with center position contained within a given region

        Only models selected by the tag "sky-model" are considered.

        Parameters
        ----------
        regions : str, `~regions.Region` or list of `~regions.Region`
            Region or list of regions (pixel or sky regions accepted).
            A region can be defined as a string in DS9 format as well.
            See http://ds9.si.edu/doc/ref/region.html for details.
        wcs : `~astropy.wcs.WCS`
            World coordinate system transformation

        Returns
        -------
        models : `DatasetModels`
            Selected models
        """
        geom = RegionGeom.from_regions(regions, wcs=wcs)

        inside = [
            model
            for model in self.select(tag="sky-model")
            if geom.contains(model.position)
        ]

        return self.__class__(models=inside)
807
808
    def restore_status(self, restore_values=True):
        """Context manager to restore status.

        A copy of the values is made on enter,
        and those values are restored on exit.

        Parameters
        ----------
        restore_values : bool
            Restore values if True,
            otherwise restore only frozen status and covariance matrix.

        Returns
        -------
        context : `restore_models_status`
            Context manager created for these models.
        """
        context = restore_models_status(self, restore_values)
        return context
822
823
    def set_parameters_bounds(
        self, tag, model_type, parameters_names=None, min=None, max=None, value=None
    ):
        """Set bounds for the selected models types and parameters names

        Settings left at None are not touched.

        Parameters
        ----------
        tag : str or list
            Tag of the models
        model_type :  {"spatial", "spectral", "temporal"}
            Type of model
        parameters_names : str or list
            parameters names
        min : float
            min value
        max : float
            max value
        value : float
            init value
        """
        models = self.select(tag=tag, model_type=model_type)
        parameters = models.parameters.select(name=parameters_names, type=model_type)
        n = len(parameters)

        # Applied in this order: min, max, value
        settings = (("min", min), ("max", max), ("value", value))

        for attribute, setting in settings:
            if setting is not None:
                setattr(parameters, attribute, np.ones(n) * setting)
853
854
    def freeze(self, model_type=None):
        """Freeze the parameters of every model in the collection.

        The call is delegated to the ``freeze`` method of each model.

        Parameters
        ----------
        model_type : {None, "spatial", "spectral"}
            Freeze all parameters or only spatial or only spectral.
        """
        for model in self:
            model.freeze(model_type)
865
866
    def unfreeze(self, model_type=None):
        """Restore the default frozen status of the parameters of every model.

        The call is delegated to the ``unfreeze`` method of each model.

        Parameters
        ----------
        model_type : {None, "spatial", "spectral"}
            Restore frozen status to default for all parameters or only
            spatial or only spectral.
        """
        for model in self:
            model.unfreeze(model_type)
877
878
    @property
    def frozen(self):
        """Whether all models are frozen.

        Combines the ``frozen`` attribute of every model with `numpy.all`,
        so a single truth value is returned for the whole collection.
        """
        flags = [model.frozen for model in self]
        return np.all(flags)
882
883
    def reassign(self, dataset_name, new_dataset_name):
        """Reassign models from one dataset to another.

        Each model's own ``reassign`` method is called, and the results are
        collected into a new collection of the same class.

        Parameters
        ----------
        dataset_name : str or list
            Name of the datasets where the model is currently defined
        new_dataset_name : str or list
            Name of the datasets where the model should be defined instead.
            If multiple names are given the two lists must have the same
            length, as the reassignment is element-wise.

        Returns
        -------
        models : same class as this collection
            Collection built from the reassigned models.
        """
        reassigned = [
            model.reassign(dataset_name, new_dataset_name) for model in self
        ]
        return self.__class__(reassigned)
897
898
    def to_template_sky_model(self, geom, spectral_model=None, name=None):
        """Merge a list of models into a single `~gammapy.modeling.models.SkyModel`

        Every model is evaluated on ``geom`` and the results are summed into
        one map, which becomes a non-normalized template spatial model.

        Parameters
        ----------
        geom : `Geom`
            Map geometry of the result template model.
        spectral_model : `~gammapy.modeling.models.SpectralModel`
            One of the NormSpectralModel. If None, a default
            `PowerLawNormSpectralModel` is used.
        name : str
            Name of the new model

        Returns
        -------
        model : `SkyModel`
            Template sky model.
        """
        from . import PowerLawNormSpectralModel, SkyModel, TemplateSpatialModel

        flux_unit = u.Unit("1 / (cm2 s sr TeV)")
        template = Map.from_geom(geom, unit=flux_unit)

        # Accumulate the contribution of each model in a common unit.
        for model in self:
            template += model.evaluate_geom(geom).to(flux_unit)

        if spectral_model is None:
            spectral_model = PowerLawNormSpectralModel()

        return SkyModel(
            spatial_model=TemplateSpatialModel(template, normalize=False),
            spectral_model=spectral_model,
            name=name,
        )
931
932
    @property
    def positions(self):
        """Positions of the models (`~astropy.coordinates.SkyCoord`)

        Only models selected with the "sky-model" tag are considered. Models
        whose position is empty are skipped with a warning. The remaining
        positions are converted to ICRS.
        """
        coords = []

        for model in self.select(tag="sky-model"):
            if not model.position:
                log.warning(
                    f"Skipping model {model.name} - no spatial component present"
                )
                continue

            coords.append(model.position.icrs)

        return SkyCoord(coords)
946
947
    def to_regions(self):
        """Returns a list of the regions for the spatial models

        Models for which the region cannot be obtained because of an
        `AttributeError` are skipped with a warning.

        Returns
        -------
        regions: list of `~regions.SkyRegion`
            Regions
        """
        collected = []

        for model in self.select(tag="sky-model"):
            try:
                region = model.spatial_model.to_region()
            except AttributeError:
                log.warning(
                    f"Skipping model {model.name} - no spatial component present"
                )
            else:
                collected.append(region)

        return collected
966
967
    @property
    def wcs_geom(self):
        """Minimum WCS geom in which all the models are contained

        If building the geometry raises `IndexError`, an error is logged and
        None is returned.
        """
        model_regions = self.to_regions()

        try:
            region_geom = RegionGeom.from_regions(model_regions)
            return region_geom.to_wcs_geom()
        except IndexError:
            log.error("No spatial component in any model. Geom not defined")
            return None
975
976
    def plot_regions(self, ax=None, kwargs_point=None, path_effect=None, **kwargs):
        """Plot extent of the spatial models on a given wcs axis

        The regions of the models are gathered into a `RegionGeom`, and all
        arguments are passed on to its ``plot_region`` method.

        Parameters
        ----------
        ax : `~astropy.visualization.WCSAxes`
            Axes to plot on.
        kwargs_point : dict
            Keyword arguments passed to `~matplotlib.lines.Line2D` for plotting
            of point sources
        path_effect : `~matplotlib.patheffects.PathEffect`
            Path effect applied to artists and lines.
        **kwargs : dict
            Keyword arguments passed to `~matplotlib.artist.Artist`

        Returns
        -------
        ax : `~astropy.visualization.WcsAxes`
            WCS axes
        """
        region_geom = RegionGeom.from_regions(regions=self.to_regions())
        return region_geom.plot_region(
            ax=ax, kwargs_point=kwargs_point, path_effect=path_effect, **kwargs
        )
1003
1004
    def plot_positions(self, ax=None, **kwargs):
        """Plot the centers of the spatial models on a given wcs axis

        Parameters
        ----------
        ax : `~astropy.visualization.WCSAxes`
            Axes to plot on. If ``ax`` is not a `WCSAxes` instance (which
            includes None), new axes are created by plotting an empty map
            defined on `wcs_geom`.
        **kwargs : dict
            Keyword arguments passed to `~matplotlib.pyplot.scatter`.
            ``marker`` defaults to "*" and ``color`` to "tab:blue".

        Returns
        -------
        ax : `~astropy.visualization.WcsAxes`
            Wcs axes
        """
        from astropy.visualization.wcsaxes import WCSAxes

        # None is not a WCSAxes either, so one check covers both cases.
        if not isinstance(ax, WCSAxes):
            ax = Map.from_geom(self.wcs_geom).plot()

        kwargs.setdefault("marker", "*")
        kwargs.setdefault("color", "tab:blue")
        path_effects = kwargs.get("path_effects")

        x_pix, y_pix = self.positions.to_pixel(ax.wcs)
        markers = ax.scatter(x_pix, y_pix, **kwargs)

        if path_effects:
            plt.setp(markers, path_effects=path_effects)

        return ax
1037
1038
1039
class Models(DatasetModels, collections.abc.MutableSequence):
    """Sky model collection.

    Mutable variant of `DatasetModels`: items can be deleted, replaced and
    inserted.

    Parameters
    ----------
    models : `SkyModel`, list of `SkyModel` or `Models`
        Sky models
    """

    def __delitem__(self, key):
        """Remove the model addressed by ``key``."""
        position = self.index(key)
        del self._models[position]

    def __setitem__(self, key, model):
        """Replace the model addressed by ``key``.

        Raises
        ------
        TypeError
            If ``model`` is neither a `SkyModel` nor a `FoVBackgroundModel`.
        """
        from gammapy.modeling.models import FoVBackgroundModel, SkyModel

        if not isinstance(model, (SkyModel, FoVBackgroundModel)):
            raise TypeError(f"Invalid type: {model!r}")

        self._models[self.index(key)] = model

    def insert(self, idx, model):
        """Insert ``model`` before position ``idx``.

        Raises
        ------
        ValueError
            If a model with the same name is already in the collection.
        """
        if model.name in self.names:
            raise ValueError("Model names must be unique")

        self._models.insert(idx, model)
1064
1065
1066
class restore_models_status:
    """Context manager restoring parameter status and covariance of models.

    The snapshot is taken when the object is created, not on ``__enter__``.

    Parameters
    ----------
    models : object exposing ``parameters`` and ``covariance``
        Collection whose status is recorded. ``parameters`` must be iterable
        and yield objects with ``value`` and ``frozen`` attributes;
        ``covariance`` must expose ``data`` and be assignable.
    restore_values : bool
        Restore values if True,
        otherwise restore only frozen status and covariance matrix.
    """

    def __init__(self, models, restore_values=True):
        self.restore_values = restore_values
        self.models = models
        self.values = [par.value for par in models.parameters]
        self.frozen = [par.frozen for par in models.parameters]
        # Store an independent copy: keeping a reference to the live data
        # would make the restore a no-op if the data is modified in place.
        self.covariance_data = np.array(models.covariance.data, copy=True)

    def __enter__(self):
        pass

    def __exit__(self, exc_type, exc_value, traceback):
        # Nothing is returned, so an exception raised inside the ``with``
        # block propagates after the status has been restored.
        for saved_value, par, saved_frozen in zip(
            self.values, self.models.parameters, self.frozen
        ):
            if self.restore_values:
                par.value = saved_value
            par.frozen = saved_frozen
        self.models.covariance = self.covariance_data
1083