Passed
Push — master ( de5fcf...a3522d )
by Axel
03:11 queued 11s
created

gammapy.modeling.models.core.ModelBase.reassign()   B

Complexity

Conditions 6

Size

Total Lines 36
Code Lines 13

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 13
dl 0
loc 36
rs 8.6666
c 0
b 0
f 0
cc 6
nop 3
1
# Licensed under a 3-clause BSD style license - see LICENSE.rst
2
import logging
3
import collections.abc
4
import copy
5
from os.path import split
6
import yaml
7
import numpy as np
8
import astropy.units as u
9
from astropy.table import Table
10
from astropy.coordinates import SkyCoord
11
from regions import PointSkyRegion
12
from gammapy.modeling import Covariance, Parameter, Parameters
13
from gammapy.utils.scripts import make_name, make_path
14
from gammapy.maps import RegionGeom, Map
15
16
17
# Module-level logger, named after this module.
log = logging.getLogger(__name__)
18
19
20
def _set_link(shared_register, model):
    """Share parameters between models according to their I/O link label.

    Parameters
    ----------
    shared_register : dict
        Mapping from link label to the parameter first seen with that label.
        It is updated in place.
    model : object
        Model whose ``parameters`` are inspected. A parameter whose label is
        already registered is replaced on the model (via ``setattr``) by the
        registered parameter.

    Returns
    -------
    shared_register : dict
        The same register object that was passed in.
    """
    for par in model.parameters:
        label = par._link_label_io

        # Parameters without a label are not linked to anything
        if label is None:
            continue

        if label not in shared_register:
            # First occurrence: this parameter becomes the shared one
            shared_register[label] = par
            continue

        setattr(model, par.name, shared_register[label])

    return shared_register
31
32
def _get_model_class_from_dict(data):
    """Get a model class from a dict.

    The registry is chosen from the first matching key, in this order:
    ``"type"`` (generic model registry), then ``"spatial"``, ``"spectral"``
    and ``"temporal"`` (the tag is then read from ``data[key]["type"]``).

    Parameters
    ----------
    data : dict
        Model description.

    Returns
    -------
    cls : type
        Model class returned by the registry's ``get_cls``.

    Raises
    ------
    ValueError
        If ``data`` contains none of the supported keys.
    """
    sub_keys = ("spatial", "spectral", "temporal")

    # Validate up-front so that unsupported input fails with a clear message
    # instead of an UnboundLocalError on the return statement.
    if "type" in data:
        key = None
    else:
        key = next((_ for _ in sub_keys if _ in data), None)
        if key is None:
            raise ValueError(
                "Cannot determine the model class: expected one of the keys "
                f"'type', 'spatial', 'spectral' or 'temporal', got {list(data)}"
            )

    from . import (
        MODEL_REGISTRY,
        SPATIAL_MODEL_REGISTRY,
        SPECTRAL_MODEL_REGISTRY,
        TEMPORAL_MODEL_REGISTRY,
    )

    if key is None:
        return MODEL_REGISTRY.get_cls(data["type"])

    registries = {
        "spatial": SPATIAL_MODEL_REGISTRY,
        "spectral": SPECTRAL_MODEL_REGISTRY,
        "temporal": TEMPORAL_MODEL_REGISTRY,
    }
    return registries[key].get_cls(data[key]["type"])
0 ignored issues
show
introduced by
The variable cls does not seem to be defined for all execution paths.
Loading history...
50
51
52
# Names exported by a star-import of this module.
__all__ = ["Model", "Models", "DatasetModels"]
53
54
55
class ModelBase:
    """Model base class.

    Sub-classes declare their parameters as class attributes that are
    `Parameter` instances; these are collected into ``default_parameters``
    when the sub-class is created (see ``__init_subclass__``).
    """

    _type = None

    def __init__(self, **kwargs):
        """Create the model, optionally overriding default parameters.

        Parameters
        ----------
        **kwargs : dict
            Keyed by parameter name. A `Parameter` value replaces the default
            parameter; any other value is converted with `~astropy.units.Quantity`
            and assigned to the (copied) default parameter.
        """
        # Copy default parameters from the class to the instance
        default_parameters = self.default_parameters.copy()

        for par in default_parameters:
            value = kwargs.get(par.name, par)

            if not isinstance(value, Parameter):
                par.quantity = u.Quantity(value)
            else:
                par = value

            setattr(self, par.name, par)
        self._covariance = Covariance(self.parameters)

    def __getattribute__(self, name):
        value = object.__getattribute__(self, name)

        # Parameters stored on the instance are resolved through their
        # own ``__get__`` rather than returned directly.
        if isinstance(value, Parameter):
            return value.__get__(self, None)

        return value

    @property
    def type(self):
        """Model type, as stored in the ``_type`` class attribute."""
        return self._type

    def __init_subclass__(cls, **kwargs):
        # Add parameters list on the model sub-class (not instances)
        cls.default_parameters = Parameters(
            [_ for _ in cls.__dict__.values() if isinstance(_, Parameter)]
        )

    @classmethod
    def from_parameters(cls, parameters, **kwargs):
        """Create model from parameter list

        Parameters
        ----------
        parameters : `Parameters`
            Parameters for init
        **kwargs : dict
            Additional keyword arguments passed to the class constructor.
            Entries named like a parameter are overwritten.

        Returns
        -------
        model : `Model`
            Model instance
        """
        for par in parameters:
            kwargs[par.name] = par
        return cls(**kwargs)

    def _check_covariance(self):
        # Rebuild the covariance if the parameters changed since it was created
        if not self.parameters == self._covariance.parameters:
            self._covariance = Covariance(self.parameters)

    @property
    def covariance(self):
        """Covariance of the model parameters.

        On every access the sub-covariance of each parameter is refreshed
        from ``par.error ** 2`` (NaN is replaced by 1).
        """
        self._check_covariance()
        for par in self.parameters:
            pars = Parameters([par])
            error = np.nan_to_num(par.error ** 2, nan=1)
            covar = Covariance(pars, data=[[error]])
            self._covariance.set_subcovariance(covar)

        return self._covariance

    @covariance.setter
    def covariance(self, covariance):
        self._check_covariance()
        self._covariance.data = covariance

        # Propagate the new variances back to the parameter errors
        for par in self.parameters:
            pars = Parameters([par])
            variance = self._covariance.get_subcovariance(pars)
            par.error = np.sqrt(variance)

    @property
    def parameters(self):
        """Parameters (`~gammapy.modeling.Parameters`)"""
        return Parameters(
            [getattr(self, name) for name in self.default_parameters.names]
        )

    def copy(self, **kwargs):
        """A deep copy.

        Keyword arguments are accepted and ignored, so that callers such as
        `reassign` can use the same call for sub-classes whose ``copy``
        takes extra arguments (e.g. ``name``).
        """
        return copy.deepcopy(self)

    @staticmethod
    def _is_nan(value):
        """Return True if ``value`` is a NaN number; non-numeric values are not NaN."""
        try:
            return bool(np.isnan(value))
        except TypeError:
            return False

    def to_dict(self, full_output=False):
        """Create dict for YAML serialisation

        Parameters
        ----------
        full_output : bool
            If False, parameter items that match the default parameter
            ("min", "max", "error", "interp", "scale_method"), a false
            "frozen" flag and an empty default "unit" are left out.

        Returns
        -------
        data : dict
            ``{"type": tag, "parameters": [...]}``, nested under the model
            type key if ``_type`` is set.
        """
        tag = self.tag[0] if isinstance(self.tag, list) else self.tag
        params = self.parameters.to_dict()

        if not full_output:
            for par, par_default in zip(params, self.default_parameters):
                init = par_default.to_dict()
                for item in ["min", "max", "error", "interp", "scale_method"]:
                    default = init[item]
                    value = par[item]

                    # NaN never compares equal to itself, so treat "both NaN"
                    # as unchanged. A value that differs from a NaN default
                    # (e.g. a user-defined bound) must be kept.
                    unchanged = value == default or (
                        ModelBase._is_nan(value) and ModelBase._is_nan(default)
                    )
                    if unchanged:
                        del par[item]

                if not par["frozen"]:
                    del par["frozen"]

                if init["unit"] == "":
                    del par["unit"]

        data = {"type": tag, "parameters": params}

        if self._type is None:
            return data

        return {self._type: data}

    @classmethod
    def from_dict(cls, data):
        """Create model from dict

        Parameters
        ----------
        data : dict
            Model description. It may be nested under a "spatial",
            "temporal" or "spectral" key. Parameters missing from
            ``data["parameters"]`` fall back to the class defaults
            (a warning is logged).

        Returns
        -------
        model : `Model`
            Model instance

        Raises
        ------
        ValueError
            If ``data["type"]`` is not among the tags of this class.
        """
        kwargs = {}

        par_data = []
        key0 = next(iter(data))
        if key0 in ["spatial", "temporal", "spectral"]:
            data = data[key0]
        if data["type"] not in cls.tag:
            raise ValueError(
                f"Invalid model type {data['type']} for Class {cls.__name__}"
            )

        input_names = [_["name"] for _ in data["parameters"]]

        for par in cls.default_parameters:
            par_dict = par.to_dict()
            try:
                index = input_names.index(par_dict["name"])
                par_dict.update(data["parameters"][index])
            except ValueError:
                log.warning(
                    f"Parameter {par_dict['name']} not defined. Using default value: {par_dict['value']} {par_dict['unit']}"
                )
            par_data.append(par_dict)

        parameters = Parameters.from_dict(par_data)

        # TODO: this is a special case for spatial models, maybe better move to `SpatialModel` base class
        if "frame" in data:
            kwargs["frame"] = data["frame"]

        return cls.from_parameters(parameters, **kwargs)

    def __str__(self):
        string = f"{self.__class__.__name__}\n"
        if len(self.parameters) > 0:
            string += f"\n{self.parameters.to_table()}"
        return string

    @property
    def frozen(self):
        """Frozen status of a model, True if all parameters are frozen"""
        return np.all([p.frozen for p in self.parameters])

    def freeze(self):
        """Freeze all parameters"""
        self.parameters.freeze_all()

    def unfreeze(self):
        """Restore parameters frozen status to default"""
        for p, default in zip(self.parameters, self.default_parameters):
            p.frozen = default.frozen

    def reassign(self, datasets_names, new_datasets_names):
        """Reassign a model from one dataset to another

        Parameters
        ----------
        datasets_names : str or list
            Name of the datasets where the model is currently defined
        new_datasets_names : str or list
            Name of the datasets where the model should be defined instead.
            If multiple names are given the two list must have the same length,
            as the reassignment is element-wise.

        Returns
        -------
        model : `Model`
            Reassigned model (a copy; the original is left untouched).
            Models without ``datasets_names`` are returned as plain copies.
        """
        # Not every model defines a name; the base ``copy`` ignores it anyway
        model = self.copy(name=getattr(self, "name", None))

        if not isinstance(datasets_names, list):
            datasets_names = [datasets_names]

        if not isinstance(new_datasets_names, list):
            new_datasets_names = [new_datasets_names]

        # Guarded access: not every model carries ``datasets_names``
        current = getattr(model, "datasets_names", None)

        if isinstance(current, str):
            model.datasets_names = [current]

        if getattr(model, "datasets_names", None):
            for name, name_new in zip(datasets_names, new_datasets_names):
                # NOTE(review): ``str.replace`` also rewrites names that merely
                # contain ``name`` as a substring - confirm this is intended.
                model.datasets_names = [
                    _.replace(name, name_new) for _ in model.datasets_names
                ]

        return model
262
263
264
class Model:
    """Model class that contains only methods to create a model listed in the registries."""

    @staticmethod
    def create(tag, model_type=None, *args, **kwargs):
        """Create a model instance.

        Parameters
        ----------
        tag : str
            Model tag, looked up in the registry.
        model_type : str, optional
            Key under which the tag is nested for the lookup
            (e.g. "spectral"). If None, the tag is looked up directly.
        *args, **kwargs
            Passed to the constructor of the selected model class.

        Examples
        --------
        >>> from gammapy.modeling.models import Model
        >>> spectral_model = Model.create("pl-2", model_type="spectral", amplitude="1e-10 cm-2 s-1", index=3)
        >>> type(spectral_model)
        <class 'gammapy.modeling.models.spectral.PowerLaw2SpectralModel'>
        """
        spec = {"type": tag}
        lookup = spec if model_type is None else {model_type: spec}
        model_cls = _get_model_class_from_dict(lookup)
        return model_cls(*args, **kwargs)

    @staticmethod
    def from_dict(data):
        """Create a model instance from a dict"""
        model_cls = _get_model_class_from_dict(data)
        return model_cls.from_dict(data)
292
293
294
295
class DatasetModels(collections.abc.Sequence):
    """Immutable models container

    Parameters
    ----------
    models : `SkyModel`, list of `SkyModel` or `Models`
        Sky models
    """

    def __init__(self, models=None):
        if models is None:
            models = []

        if isinstance(models, (Models, DatasetModels)):
            models = models._models
        elif isinstance(models, ModelBase):
            models = [models]
        elif not isinstance(models, list):
            raise TypeError(f"Invalid type: {models!r}")

        # A set keeps the duplicate check linear in the number of models
        seen_names = set()
        for model in models:
            if model.name in seen_names:
                raise ValueError("Model names must be unique")
            seen_names.add(model.name)

        self._models = models
        self._covar_file = None
        self._covariance = Covariance(self.parameters)

    def _check_covariance(self):
        # Rebuild the stacked covariance if the parameters changed
        if not self.parameters == self._covariance.parameters:
            self._covariance = Covariance.from_stack(
                [model.covariance for model in self._models]
            )

    @property
    def covariance(self):
        """Covariance of all parameters, refreshed from each model's covariance."""
        self._check_covariance()

        for model in self._models:
            self._covariance.set_subcovariance(model.covariance)

        return self._covariance

    @covariance.setter
    def covariance(self, covariance):
        self._check_covariance()
        self._covariance.data = covariance

        # Push the matching sub-covariance down to every model
        for model in self._models:
            subcovar = self._covariance.get_subcovariance(model.covariance.parameters)
            model.covariance = subcovar

    @property
    def parameters(self):
        """Parameters of all models, stacked (`Parameters`)"""
        return Parameters.from_stack([_.parameters for _ in self._models])

    @property
    def parameters_unique_names(self):
        """List of unique parameter names as model_name.par_type.par_name"""
        names = []
        for model in self:
            for par in model.parameters:
                components = [model.name, par.type, par.name]
                name = ".".join(components)
                names.append(name)

        return names

    @property
    def names(self):
        """List of model names"""
        return [m.name for m in self._models]

    @classmethod
    def read(cls, filename):
        """Read from YAML file."""
        yaml_str = make_path(filename).read_text()
        path, filename = split(filename)
        return cls.from_yaml(yaml_str, path=path)

    @classmethod
    def from_yaml(cls, yaml_str, path=""):
        """Create from YAML string."""
        data = yaml.safe_load(yaml_str)
        return cls.from_dict(data, path=path)

    @classmethod
    def from_dict(cls, data, path=""):
        """Create from dict.

        Parameters
        ----------
        data : dict
            Must contain a "components" list; an optional "covariance"
            entry names the covariance file to read.
        path : str
            Directory in which the covariance file is looked up first.
        """
        from . import MODEL_REGISTRY, SkyModel

        models = []

        for component in data["components"]:
            model_cls = MODEL_REGISTRY.get_cls(component["type"])
            model = model_cls.from_dict(component)
            models.append(model)

        models = cls(models)

        if "covariance" in data:
            filename = data["covariance"]
            path = make_path(path)
            # Fall back to the location encoded in the file name itself
            if not (path / filename).exists():
                path, filename = split(filename)

            models.read_covariance(path, filename, format="ascii.fixed_width")

        # Re-establish links between parameters sharing the same link label
        shared_register = {}
        for model in models:
            if isinstance(model, SkyModel):
                submodels = [
                    model.spectral_model,
                    model.spatial_model,
                    model.temporal_model,
                ]
                for submodel in submodels:
                    if submodel is not None:
                        shared_register = _set_link(shared_register, submodel)
            else:
                shared_register = _set_link(shared_register, model)
        return models

    def write(
        self,
        path,
        overwrite=False,
        full_output=False,
        overwrite_templates=False,
        write_covariance=True,
    ):
        """Write to YAML file.

        Parameters
        ----------
        path : `pathlib.Path` or str
            path to write files
        overwrite : bool
            overwrite YAML files
        full_output : bool
            passed on to `to_yaml`
        overwrite_templates : bool
            overwrite templates FITS files
        write_covariance : bool
            save covariance or not

        Raises
        ------
        IOError
            If ``path`` exists and ``overwrite`` is False.
        """
        base_path, _ = split(path)
        path = make_path(path)
        base_path = make_path(base_path)

        if path.exists() and not overwrite:
            raise IOError(f"File exists already: {path}")

        if (
            write_covariance
            and self.covariance is not None
            and len(self.parameters) != 0
        ):
            # The covariance is stored next to the YAML file
            filecovar = path.stem + "_covariance.dat"
            kwargs = dict(
                format="ascii.fixed_width", delimiter="|", overwrite=overwrite
            )
            self.write_covariance(base_path / filecovar, **kwargs)
            self._covar_file = filecovar

        path.write_text(self.to_yaml(full_output, overwrite_templates))

    def to_yaml(self, full_output=False, overwrite_templates=False):
        """Convert to YAML string."""
        data = self.to_dict(full_output, overwrite_templates)
        return yaml.dump(
            data, sort_keys=False, indent=4, width=80, default_flow_style=False
        )

    def to_dict(self, full_output=False, overwrite_templates=False):
        """Convert to dict.

        Parameters appearing more than once get a fresh link label before
        serialisation. Template spatial models are written to disk as a
        side effect.
        """
        # update linked parameters labels
        params_list = []
        params_shared = []
        for param in self.parameters:
            if param not in params_list:
                params_list.append(param)
            elif param not in params_shared:
                params_shared.append(param)
        for param in params_shared:
            param._link_label_io = param.name + "@" + make_name()

        models_data = []
        for model in self._models:
            model_data = model.to_dict(full_output)
            models_data.append(model_data)
            if (
                hasattr(model, "spatial_model")
                and model.spatial_model is not None
                and "template" in model.spatial_model.tag
            ):
                model.spatial_model.write(overwrite=overwrite_templates)

        if self._covar_file is not None:
            return {
                "components": models_data,
                "covariance": str(self._covar_file),
            }

        return {"components": models_data}

    def to_parameters_table(self):
        """Convert Models parameters to an astropy Table."""
        table = self.parameters.to_table()
        # Warning: splitting of parameters will break if source name has a "." in its name.
        model_name = [name.split(".")[0] for name in self.parameters_unique_names]
        table.add_column(model_name, name="model", index=0)
        self._table_cached = table
        return table

    def update_parameters_from_table(self, t):
        """Update Models from an astropy Table."""
        parameters_dict = [dict(zip(t.colnames, row)) for row in t]
        for k, data in enumerate(parameters_dict):
            self.parameters[k].update_from_dict(data)

    def read_covariance(self, path, filename="_covariance.dat", **kwargs):
        """Read covariance data from file

        Parameters
        ----------
        path : str or `pathlib.Path`
            Directory containing the file
        filename : str
            Filename
        **kwargs : dict
            Keyword arguments passed to `~astropy.table.Table.read`

        """
        path = make_path(path)
        filepath = str(path / filename)
        t = Table.read(filepath, **kwargs)
        t.remove_column("Parameters")
        arr = np.array(t)
        # Reinterpret the remaining table columns as one float matrix
        data = arr.view(float).reshape(arr.shape + (-1,))
        self.covariance = data
        self._covar_file = filename

    def write_covariance(self, filename, **kwargs):
        """Write covariance to file

        Parameters
        ----------
        filename : str
            Filename
        **kwargs : dict
            Keyword arguments passed to `~astropy.table.Table.write`

        """
        names = self.parameters_unique_names
        table = Table()
        table["Parameters"] = names

        for idx, name in enumerate(names):
            values = self.covariance.data[idx]
            table[name] = values

        table.write(make_path(filename), **kwargs)

    def __str__(self):
        str_ = f"{self.__class__.__name__}\n\n"

        for idx, model in enumerate(self):
            str_ += f"Component {idx}: "
            str_ += str(model)

        return str_.expandtabs(tabsize=2)

    def __add__(self, other):
        # Any models container is accepted here, consistent with ``__init__``
        # (``Models`` is a sub-class of ``DatasetModels``).
        if isinstance(other, (DatasetModels, list)):
            return Models([*self, *other])
        elif isinstance(other, ModelBase):
            if other.name in self.names:
                raise ValueError("Model names must be unique")
            return Models([*self, other])
        else:
            raise TypeError(f"Invalid type: {other!r}")

    def __getitem__(self, key):
        if isinstance(key, np.ndarray) and key.dtype == bool:
            return self.__class__(list(np.array(self._models)[key]))
        else:
            return self._models[self.index(key)]

    def index(self, key):
        """Convert ``key`` (int, slice, model name or model) to a list index."""
        if isinstance(key, (int, slice)):
            return key
        elif isinstance(key, str):
            return self.names.index(key)
        elif isinstance(key, ModelBase):
            return self._models.index(key)
        else:
            raise TypeError(f"Invalid type: {type(key)!r}")

    def __len__(self):
        return len(self._models)

    def _ipython_key_completions_(self):
        return self.names

    def copy(self):
        """A deep copy."""
        return copy.deepcopy(self)

    def select(
        self,
        name_substring=None,
        datasets_names=None,
        tag=None,
        model_type=None,
        frozen=None,
    ):
        """Select models that meet all specified conditions

        Parameters
        ----------
        name_substring : str
            Substring contained in the model name
        datasets_names : str or list
            Name of the dataset
        tag : str or list
            Model tag
        model_type : {None, spatial, spectral}
           Type of model, used together with "tag", if the tag is not unique.
        frozen : bool
            Select models with all parameters frozen if True, exclude them if False.

        Returns
        -------
        models : `DatasetModels`
            Selected models
        """
        mask = self.selection_mask(
            name_substring, datasets_names, tag, model_type, frozen
        )
        return self[mask]

    def selection_mask(
        self,
        name_substring=None,
        datasets_names=None,
        tag=None,
        model_type=None,
        frozen=None,
    ):
        """Create a mask of models, that meet all specified conditions

        Parameters
        ----------
        name_substring : str
            Substring contained in the model name
        datasets_names : str or list of str
            Name of the dataset
        tag : str or list of str
            Model tag
        model_type : {None, spatial, spectral}
           Type of model, used together with "tag", if the tag is not unique.
        frozen : bool
            Select models with all parameters frozen if True, exclude them if False.

        Returns
        -------
        mask : `numpy.array`
            Boolean mask, True for selected models
        """
        selection = np.ones(len(self), dtype=bool)

        if tag and not isinstance(tag, list):
            tag = [tag]

        if datasets_names and not isinstance(datasets_names, list):
            datasets_names = [datasets_names]

        for idx, model in enumerate(self):
            if name_substring:
                selection[idx] &= name_substring in model.name

            if datasets_names:
                # Models without dataset assignment match every dataset
                selection[idx] &= model.datasets_names is None or np.any(
                    [name in model.datasets_names for name in datasets_names]
                )

            if tag:
                if model_type is None:
                    sub_model = model
                else:
                    sub_model = getattr(model, f"{model_type}_model", None)

                if sub_model:
                    selection[idx] &= np.any([t in sub_model.tag for t in tag])
                else:
                    selection[idx] &= False

            if frozen is not None:
                # Compare as plain booleans: bitwise ``~`` on a Python bool
                # would yield an int rather than the logical negation.
                selection[idx] &= bool(model.frozen) == bool(frozen)

        return np.array(selection, dtype=bool)

    def select_mask(self, mask, margin="0 deg", use_evaluation_region=True):
        """Check if sky models contribute within a mask map.

        Parameters
        ----------
        mask : `~gammapy.maps.WcsNDMap` of boolean type
            Map containing a boolean mask
        margin : `~astropy.unit.Quantity`
            Add a margin in degree to the source evaluation radius.
            Used to take into account PSF width.
        use_evaluation_region : bool
            Account for the extension of the model or not. The default is True.

        Returns
        -------
        models : `DatasetModels`
            Selected models contributing inside the region where mask==True
        """
        models = []

        if not mask.geom.is_image:
            mask = mask.reduce_over_axes(func=np.logical_or)

        for model in self.select(tag="sky-model"):
            if use_evaluation_region:
                contributes = model.contributes(mask=mask, margin=margin)
            else:
                contributes = mask.get_by_coord(model.position, fill_value=0)

            if np.any(contributes):
                models.append(model)

        return self.__class__(models=models)

    def select_region(self, regions, wcs=None):
        """Select sky models with center position contained within a given region

        Parameters
        ----------
        regions : str, `~regions.Region` or list of `~regions.Region`
            Region or list of regions (pixel or sky regions accepted).
            A region can be defined as a string in DS9 format as well.
            See http://ds9.si.edu/doc/ref/region.html for details.
        wcs : `~astropy.wcs.WCS`
            World coordinate system transformation

        Returns
        -------
        models : `DatasetModels`
            Selected models
        """
        geom = RegionGeom.from_regions(regions, wcs=wcs)

        models = []

        for model in self.select(tag="sky-model"):
            if geom.contains(model.position):
                models.append(model)

        return self.__class__(models=models)

    def restore_status(self, restore_values=True):
        """Context manager to restore status.

        A copy of the values is made on enter,
        and those values are restored on exit.

        Parameters
        ----------
        restore_values : bool
            Restore values if True,
            otherwise restore only frozen status and covariance matrix.

        """
        return restore_models_status(self, restore_values)

    def set_parameters_bounds(
        self, tag, model_type, parameters_names, min=None, max=None, value=None
    ):
        """Set bounds for the selected models types and parameters names

        Parameters
        ----------
        tag : str or list
            tag of the models
        model_type : {"spatial", "spectral"}
            type of models
        parameters_names : str or list
            parameters names
        min : float
            min value
        max : float
            max value
        value : float
            init value
        """

        models = self.select(tag=tag, model_type=model_type)
        parameters = models.parameters.select(name=parameters_names, type=model_type)
        n = len(parameters)

        if min is not None:
            parameters.min = np.ones(n) * min
        if max is not None:
            parameters.max = np.ones(n) * max
        if value is not None:
            parameters.value = np.ones(n) * value

    def freeze(self, model_type=None):
        """Freeze parameters depending on model type

        Parameters
        ----------
        model_type : {None, "spatial", "spectral"}
           freeze all parameters or only spatial or only spectral
        """

        for m in self:
            m.freeze(model_type)

    def unfreeze(self, model_type=None):
        """Restore parameters frozen status to default depending on model type

        Parameters
        ----------
        model_type : {None, "spatial", "spectral"}
           restore frozen status to default for all parameters or only spatial or only spectral
        """

        for m in self:
            m.unfreeze(model_type)

    @property
    def frozen(self):
        """True if every model in the container is frozen"""
        return np.all([m.frozen for m in self])

    def reassign(self, dataset_name, new_dataset_name):
        """Reassign a model from one dataset to another

        Parameters
        ----------
        dataset_name : str or list
            Name of the datasets where the model is currently defined
        new_dataset_name : str or list
            Name of the datasets where the model should be defined instead.
            If multiple names are given the two list must have the same length,
            as the reassignment is element-wise.

        Returns
        -------
        models : `DatasetModels`
            Container of the reassigned models
        """
        models = [m.reassign(dataset_name, new_dataset_name) for m in self]
        return self.__class__(models)

    def to_template_sky_model(self, geom, spectral_model=None, name=None):
        """Merge a list of models into a single `~gammapy.modeling.models.SkyModel`

        Parameters
        ----------
        geom : geometry
            Geometry on which every model is evaluated and summed
        spectral_model : `~gammapy.modeling.models.SpectralModel`
            One of the NormSpectralModel. If None, a default
            `PowerLawNormSpectralModel` is used.
        name : str
            Name of the new model

        """
        from . import SkyModel, TemplateSpatialModel, PowerLawNormSpectralModel

        unit = u.Unit("1 / (cm2 s sr TeV)")
        map_ = Map.from_geom(geom, unit=unit)
        for m in self:
            map_ += m.evaluate_geom(geom).to(unit)
        spatial_model = TemplateSpatialModel(map_, normalize=False)
        if spectral_model is None:
            spectral_model = PowerLawNormSpectralModel()
        return SkyModel(
            spectral_model=spectral_model, spatial_model=spatial_model, name=name
        )

    @property
    def positions(self):
        """Positions of the models (`SkyCoord`)"""
        positions = []

        for model in self.select(tag="sky-model"):
            if model.position:
                positions.append(model.position)
            else:
                log.warning(
                    f"Skipping model {model.name} - no spatial component present"
                )

        return SkyCoord(positions)

    def to_regions(self):
        """Returns a list of the regions for the spatial models

        Returns
        -------
        regions: list of `~regions.SkyRegion`
            Regions
        """
        regions = []

        for model in self.select(tag="sky-model"):
            try:
                region = model.spatial_model.to_region()
                regions.append(region)
            except AttributeError:
                log.warning(
                    f"Skipping model {model.name} - no spatial component present"
                )
        return regions

    @property
    def wcs_geom(self):
        """Minimum WCS geom in which all the models are contained

        An error is logged and None is returned if no region is available.
        """
        regions = self.to_regions()
        try:
            return RegionGeom.from_regions(regions).to_wcs_geom()
        except IndexError:
            log.error("No spatial component in any model. Geom not defined")

    def plot_regions(self, ax=None, kwargs_point=None, path_effect=None, **kwargs):
        """Plot extent of the spatial models on a given wcs axis

        Parameters
        ----------
        ax : `~astropy.visualization.WCSAxes`
            Axes to plot on. If no axes are given, an all-sky wcs
            is chosen using a CAR projection
        kwargs_point : dict
            Keyword arguments passed to `~matplotlib.lines.Line2D` for plotting
            of point sources
        path_effect : `~matplotlib.patheffects.PathEffect`
            Path effect applied to artists and lines.
        **kwargs : dict
            Keyword arguments passed to `~matplotlib.artists.Artist`


        Returns
        -------
        ax : `~astropy.visualization.WcsAxes`
            WCS axes
        """
        from astropy.visualization.wcsaxes import WCSAxes

        kwargs_point = kwargs_point or {}

        if ax is None or not isinstance(ax, WCSAxes):
            ax = Map.from_geom(self.wcs_geom).plot()

        kwargs.setdefault("color", "tab:blue")
        kwargs.setdefault("fc", "None")
        kwargs_point.setdefault("marker", "*")
        kwargs_point.setdefault("markersize", 10)
        kwargs_point.setdefault("markeredgecolor", "None")
        kwargs_point.setdefault("color", kwargs["color"])

        for region in self.to_regions():
            if isinstance(region, PointSkyRegion):
                artist = region.to_pixel(ax.wcs).as_artist(**kwargs_point)
            else:
                artist = region.to_pixel(ax.wcs).as_artist(**kwargs)

            if path_effect:
                artist.set_path_effects([path_effect])

            ax.add_artist(artist)

        return ax

    def plot_positions(self, ax=None, **kwargs):
        """Plot the centers of the spatial models on a given wcs axis

        Parameters
        ----------
        ax : `~astropy.visualization.WCSAxes`
            Axes to plot on. If no axes are given, an all-sky wcs
            is chosen using a CAR projection
        **kwargs : dict
            Keyword arguments passed to `~matplotlib.pyplot.scatter`


        Returns
        -------
        ax : `~astropy.visualization.WcsAxes`
            Wcs axes
        """
        from astropy.visualization.wcsaxes import WCSAxes
        import matplotlib.pyplot as plt

        if ax is None or not isinstance(ax, WCSAxes):
            ax = Map.from_geom(self.wcs_geom).plot()

        kwargs.setdefault("marker", "*")
        kwargs.setdefault("color", "tab:blue")
        path_effects = kwargs.get("path_effects", None)

        xp, yp = self.positions.to_pixel(ax.wcs)
        p = ax.scatter(xp, yp, **kwargs)

        if path_effects:
            plt.setp(p, path_effects=path_effects)

        return ax
1003
1004
1005
class Models(DatasetModels, collections.abc.MutableSequence):
    """Sky model collection.

    Mutable counterpart of the models container: items can be deleted,
    replaced and inserted.

    Parameters
    ----------
    models : `SkyModel`, list of `SkyModel` or `Models`
        Sky models
    """

    def __delitem__(self, key):
        position = self.index(key)
        del self._models[position]

    def __setitem__(self, key, model):
        from gammapy.modeling.models import SkyModel, FoVBackgroundModel

        # Only sky models and FoV background models may be stored
        if not isinstance(model, (SkyModel, FoVBackgroundModel)):
            raise TypeError(f"Invalid type: {model!r}")

        position = self.index(key)
        self._models[position] = model

    def insert(self, idx, model):
        """Insert ``model`` before position ``idx``; its name must not be in use."""
        name_taken = model.name in self.names
        if name_taken:
            raise ValueError("Model names must be unique")

        self._models.insert(idx, model)
1030
1031
1032
class restore_models_status:
    """Context manager that restores the status of a models container on exit.

    Parameter values, frozen flags and the covariance data are recorded when
    the manager is created and written back when the ``with`` block is left.

    Parameters
    ----------
    models : `DatasetModels`
        Container providing ``parameters`` and ``covariance``
    restore_values : bool
        Restore parameter values if True, otherwise restore only the
        frozen status and the covariance.
    """

    def __init__(self, models, restore_values=True):
        self.restore_values = restore_values
        self.models = models
        self.values = [_.value for _ in models.parameters]
        self.frozen = [_.frozen for _ in models.parameters]
        self.covariance_data = models.covariance.data

    def __enter__(self):
        # Return the manager so that ``with ... as ctx`` gives access to the
        # recorded state.
        return self

    def __exit__(self, type, value, traceback):
        # Distinct loop names: the exception arguments must not be clobbered
        recorded = zip(self.values, self.models.parameters, self.frozen)

        for saved_value, par, saved_frozen in recorded:
            if self.restore_values:
                par.value = saved_value
            par.frozen = saved_frozen

        self.models.covariance = self.covariance_data
1049