gammapy.modeling.parameter.Parameter.__init__()   B
last analyzed

Complexity

Conditions 4

Size

Total Lines 49
Code Lines 42

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 4
eloc 42
nop 17
dl 0
loc 49
rs 8.872
c 0
b 0
f 0

How to fix   Many Parameters   

Many Parameters

Methods with many parameters are hard to understand, and their parameter lists also tend to become inconsistent as more or different data needs to be passed in.

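There are several approaches to avoid long parameter lists, for example grouping related arguments into a single parameter object, passing a whole existing object instead of several of its fields, or letting the method compute a value itself instead of receiving it as an argument: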

1
# Licensed under a 3-clause BSD style license - see LICENSE.rst
2
"""Model parameter classes."""
3
import collections.abc
4
import copy
5
import itertools
6
import logging
7
import numpy as np
8
from astropy import units as u
9
from gammapy.utils.interpolation import interpolation_scale
10
from gammapy.utils.table import table_from_row_data
11
12
# Public names exported by ``from gammapy.modeling.parameter import *``.
__all__ = [
    "Parameter",
    "Parameters",
]

# Module-level logger; ``Parameter.check_limits`` emits its warnings here.
log = logging.getLogger(__name__)
15
16
17
def _get_parameters_str(parameters):
18
    str_ = ""
19
20
    for par in parameters:
21
        if par.name == "amplitude":
22
            value_format, error_format = "{:10.2e}", "{:7.1e}"
23
        else:
24
            value_format, error_format = "{:10.3f}", "{:7.2f}"
25
26
        line = "\t{:21} {:8}: " + value_format + "\t {} {:<12s}\n"
27
28
        if par._link_label_io is not None:
29
            name = par._link_label_io
30
        else:
31
            name = par.name
32
33
        if par.frozen:
34
            frozen, error = "(frozen)", "\t\t"
35
        else:
36
            frozen = ""
37
            try:
38
                error = "+/- " + error_format.format(par.error)
39
            except AttributeError:
40
                error = ""
41
        str_ += line.format(name, frozen, par.value, error, par.unit)
42
    return str_.expandtabs(tabsize=2)
43
44
45
class Parameter:
    """A model parameter.

    Note that the parameter value has been split into
    a factor and scale like this::

        value = factor x scale

    Users should interact with the ``value``, ``quantity``
    or ``min`` and ``max`` properties and consider the fact
    that there is a ``factor`` and ``scale`` an implementation detail.

    That was introduced for numerical stability in parameter and error
    estimation methods, only in the Gammapy optimiser interface do we
    interact with the ``factor``, ``factor_min`` and ``factor_max`` properties,
    i.e. the optimiser "sees" the well-scaled problem.

    Parameters
    ----------
    name : str
        Name
    value : float, str or `~astropy.units.Quantity`
        Value. A `~astropy.units.Quantity` (or a string parsed into one) sets
        ``value`` and ``unit`` directly; the ``unit`` argument is then ignored.
        A plain number is stored as ``factor``, i.e. the resulting ``value``
        is ``value * scale``.
    unit : `~astropy.units.Unit` or str, optional
        Unit, only used when ``value`` is a plain number.
    scale : float, optional
        Scale (sometimes used in fitting)
    min : float, optional
        Minimum (sometimes used in fitting)
    max : float, optional
        Maximum (sometimes used in fitting)
    frozen : bool, optional
        Frozen? (used in fitting)
    error : float
        Parameter error
    scan_min : float
        Minimum value for the parameter scan. Overwrites scan_n_sigma.
    scan_max : float
        Maximum value for the parameter scan. Overwrites scan_n_sigma.
    scan_n_values: int
        Number of values to be used for the parameter scan.
    scan_n_sigma : int
        Number of sigmas to scan.
    scan_values: `numpy.array`
        Scan values. Overwrites all of the scan keywords before.
    scale_method : {'scale10', 'factor1', None}
        Method used to set ``factor`` and ``scale``
    interp : {"lin", "sqrt", "log"}
        Parameter scaling to use for the scan.
    is_norm : bool
        Whether the parameter represents the flux norm of the model.
    """

    def __init__(
        self,
        name,
        value,
        unit="",
        scale=1,
        min=np.nan,
        max=np.nan,
        frozen=False,
        error=0,
        scan_min=None,
        scan_max=None,
        scan_n_values=11,
        scan_n_sigma=2,
        scan_values=None,
        scale_method="scale10",
        interp="lin",
        is_norm=False,
    ):
        if not isinstance(name, str):
            raise TypeError(f"Name must be string, got '{type(name)}' instead")

        self._name = name
        self._link_label_io = None
        # ``scale`` must be set before the value, as the value setter divides by it.
        self.scale = scale
        self.min = min
        self.max = max
        self.frozen = frozen
        # Stored as given (not converted through the ``error`` setter).
        self._error = error
        self._is_norm = is_norm
        self._type = None

        self._init_value_and_unit(value, unit)

        self.scan_min = scan_min
        self.scan_max = scan_max
        self.scan_values = scan_values
        self.scan_n_values = scan_n_values
        self.scan_n_sigma = scan_n_sigma
        self.interp = interp
        self.scale_method = scale_method

    def _init_value_and_unit(self, value, unit):
        """Set ``value``/``factor`` and ``unit`` from the constructor arguments."""
        # TODO: behaviour is not clear if a Quantity *and* ``unit`` are passed:
        # currently ``unit`` is silently ignored in that case.
        if isinstance(value, (u.Quantity, str)):
            quantity = u.Quantity(value)
            self.value = quantity.value
            self.unit = quantity.unit
        else:
            # Plain numbers are interpreted as ``factor`` (value = factor * scale).
            self.factor = value
            self.unit = unit

    def __get__(self, instance, owner):
        """Descriptor access.

        On the class, return this descriptor. On an instance, return the
        parameter stored in ``instance.__dict__`` under this name, after
        setting its ``type`` from the instance's ``type`` attribute (or None).
        """
        if instance is None:
            return self

        par = instance.__dict__[self.name]
        par._type = getattr(instance, "type", None)
        return par

    def __set__(self, instance, value):
        """Descriptor assignment: only `Parameter` objects may be assigned."""
        if isinstance(value, Parameter):
            instance.__dict__[self.name] = value
        else:
            par = instance.__dict__[self.name]
            raise TypeError(f"Cannot assign {value!r} to parameter {par!r}")

    def __set_name__(self, owner, name):
        """Check that the class attribute name matches the parameter name."""
        if not self._name == name:
            raise ValueError(f"Expected parameter name '{name}', got {self._name}")

    @property
    def is_norm(self):
        """Whether the parameter represents the norm of the model"""
        return self._is_norm

    @property
    def type(self):
        """Type of the owning model, as set on descriptor access (or None)."""
        return self._type

    @property
    def error(self):
        """Parameter error (float, in units of ``unit``)."""
        return self._error

    @error.setter
    def error(self, value):
        # A Quantity in an equivalent unit is converted to ``self.unit``.
        self._error = float(u.Quantity(value, unit=self.unit).value)

    @property
    def name(self):
        """Name (str)."""
        return self._name

    @property
    def factor(self):
        """Factor (float)."""
        return self._factor

    @factor.setter
    def factor(self, val):
        self._factor = float(val)

    @property
    def scale(self):
        """Scale (float)."""
        return self._scale

    @scale.setter
    def scale(self, val):
        self._scale = float(val)

    @property
    def unit(self):
        """Unit (`~astropy.units.Unit`)."""
        return self._unit

    @unit.setter
    def unit(self, val):
        self._unit = u.Unit(val)

    @property
    def min(self):
        """Minimum (float)."""
        return self._min

    @min.setter
    def min(self, val):
        "Astropy Table has masked values for NaN. Replacing with np.nan."
        if isinstance(val, np.ma.core.MaskedConstant):
            self._min = np.nan
        else:
            self._min = float(val)

    @property
    def factor_min(self):
        """Factor min (float).

        This ``factor_min = min / scale`` is for the optimizer interface.
        """
        return self.min / self.scale

    @property
    def max(self):
        """Maximum (float)."""
        return self._max

    @max.setter
    def max(self, val):
        "Astropy Table has masked values for NaN. Replacing with np.nan."
        if isinstance(val, np.ma.core.MaskedConstant):
            self._max = np.nan
        else:
            self._max = float(val)

    @property
    def factor_max(self):
        """Factor max (float).

        This ``factor_max = max / scale`` is for the optimizer interface.
        """
        return self.max / self.scale

    @property
    def scale_method(self):
        """Method used to set ``factor`` and ``scale``"""
        return self._scale_method

    @scale_method.setter
    def scale_method(self, val):
        if val not in ["scale10", "factor1"] and val is not None:
            raise ValueError(f"Invalid method: {val}")
        self._scale_method = val

    @property
    def frozen(self):
        """Frozen? (used in fitting) (bool)."""
        return self._frozen

    @frozen.setter
    def frozen(self, val):
        # The flag may arrive as the strings "True"/"False". Compare explicitly:
        # ``bool("False")`` would be True.
        if isinstance(val, str) and val in ("True", "False"):
            val = val == "True"
        if not isinstance(val, (bool, np.bool_)):
            raise TypeError(f"Invalid type: {val}, {type(val)}")
        self._frozen = val

    @property
    def value(self):
        """Value = factor x scale (float)."""
        return self._factor * self._scale

    @value.setter
    def value(self, val):
        self.factor = float(val) / self._scale

    @property
    def quantity(self):
        """Value times unit (`~astropy.units.Quantity`)."""
        return self.value * self.unit

    @quantity.setter
    def quantity(self, val):
        # Only ``value`` and ``unit`` are updated; ``min``, ``max`` and
        # ``error`` are kept as plain numbers and are not converted.
        val = u.Quantity(val)

        if not val.unit.is_equivalent(self.unit):
            raise u.UnitConversionError(
                f"Unit must be equivalent to {self.unit} for parameter {self.name}"
            )

        self.value = val.value
        self.unit = val.unit

    # TODO: possibly allow to set this independently
    @property
    def conf_min(self):
        """Confidence min value (`float`)

        Returns parameter minimum if defined else the scan_min
        """
        if not np.isnan(self.min):
            return self.min
        else:
            return self.scan_min

    # TODO: possibly allow to set this independently
    @property
    def conf_max(self):
        """Confidence max value (`float`)

        Returns parameter maximum if defined else the scan_max
        """
        if not np.isnan(self.max):
            return self.max
        else:
            return self.scan_max

    @property
    def scan_min(self):
        """Stat scan min (defaults to ``value - scan_n_sigma * error``)."""
        if self._scan_min is None:
            return self.value - self.error * self.scan_n_sigma

        return self._scan_min

    @property
    def scan_max(self):
        """Stat scan max (defaults to ``value + scan_n_sigma * error``)."""
        if self._scan_max is None:
            return self.value + self.error * self.scan_n_sigma

        return self._scan_max

    @scan_min.setter
    def scan_min(self, value):
        """Stat scan min setter"""
        self._scan_min = value

    @scan_max.setter
    def scan_max(self, value):
        """Stat scan max setter"""
        self._scan_max = value

    @property
    def scan_n_sigma(self):
        """Stat scan n sigma"""
        return self._scan_n_sigma

    @scan_n_sigma.setter
    def scan_n_sigma(self, n_sigma):
        """Stat scan n sigma"""
        self._scan_n_sigma = int(n_sigma)

    @property
    def scan_values(self):
        """Stat scan values (`~numpy.ndarray`).

        If not set explicitly, ``scan_n_values`` values between ``scan_min``
        and ``scan_max``, evenly spaced in the ``interp`` scaling.
        """
        if self._scan_values is None:
            scale = interpolation_scale(self.interp)
            parmin, parmax = scale([self.scan_min, self.scan_max])
            values = np.linspace(parmin, parmax, self.scan_n_values)
            return scale.inverse(values)

        return self._scan_values

    @scan_values.setter
    def scan_values(self, values):
        """Set scan values"""
        self._scan_values = values

    def check_limits(self):
        """Emit a warning if the value is at or outside the [min, max] bounds.

        NaN bounds are ignored and frozen parameters are not checked.
        """
        if self.frozen:
            return

        below = not np.isnan(self.min) and self.value <= self.min
        above = not np.isnan(self.max) and self.value >= self.max
        if below or above:
            log.warning(
                f"Value {self.value} is outside bounds [{self.min}, {self.max}]"
                f" for parameter '{self.name}'"
            )

    def __repr__(self):
        return (
            f"{self.__class__.__name__}(name={self.name!r}, value={self.value!r}, "
            f"factor={self.factor!r}, scale={self.scale!r}, unit={self.unit!r}, "
            f"min={self.min!r}, max={self.max!r}, frozen={self.frozen!r}, id={hex(id(self))})"
        )

    def copy(self):
        """A deep copy"""
        return copy.deepcopy(self)

    def update_from_dict(self, data):
        """Update parameters from a dict.

        Only ``value``, ``unit``, ``min``, ``max`` and ``frozen`` are read
        (all are required), which protects against changing the parameter
        model, type or name.
        """
        keys = ["value", "unit", "min", "max", "frozen"]
        for k in keys:
            setattr(self, k, data[k])

    def to_dict(self):
        """Convert to dict."""
        output = {
            "name": self.name,
            "value": self.value,
            "unit": self.unit.to_string("fits"),
            "error": self.error,
            "min": self.min,
            "max": self.max,
            "frozen": self.frozen,
            "interp": self.interp,
            "scale_method": self.scale_method,
            "is_norm": self.is_norm,
        }

        if self._link_label_io is not None:
            output["link"] = self._link_label_io

        return output

    def autoscale(self):
        """Autoscale the parameters.

        Set ``factor`` and ``scale`` according to ``scale_method`` attribute

        Available ``scale_method``

        * ``scale10`` sets ``scale`` to power of 10,
          so that abs(factor) is in the range 1 to 10
        * ``factor1`` sets ``scale = abs(value)``, so that ``factor`` is +1 or -1

        In both cases the sign of value is stored in ``factor``,
        i.e. the ``scale`` is always positive (this keeps
        ``factor_min <= factor_max``).
        If ``scale_method`` is None, or the value is zero, the scaling is
        left unchanged.
        """
        value = self.value
        if value == 0:
            # The scale of a zero value is undefined; a zero scale would make
            # any later ``value`` assignment divide by zero.
            return

        if self.scale_method == "scale10":
            scale = np.power(10.0, np.floor(np.log10(np.abs(value))))
        elif self.scale_method == "factor1":
            scale = np.abs(value)
        else:
            return

        self.factor = value / scale
        self.scale = scale
459
460
461
class Parameters(collections.abc.Sequence):
    """Parameters container.

    An ordered sequence of `Parameter` objects, indexable by position,
    slice, parameter name, `Parameter` object or boolean mask.

    Parameters
    ----------
    parameters : list of `Parameter`
        List of parameters
    """

    def __init__(self, parameters=None):
        self._parameters = [] if parameters is None else list(parameters)

    def __iter__(self):
        # Iterate the list directly instead of the Sequence mixin's
        # repeated ``__getitem__`` calls.
        return iter(self._parameters)

    def check_limits(self):
        """Check parameter limits and emit a warning"""
        for par in self:
            par.check_limits()

    @property
    def types(self):
        """Parameter types"""
        return [par.type for par in self]

    @property
    def min(self):
        """Parameter minima (`numpy.ndarray`)."""
        return np.array([_.min for _ in self._parameters], dtype=np.float64)

    @min.setter
    def min(self, min_array):
        """Parameter minima (`numpy.ndarray`)."""
        if not len(self) == len(min_array):
            raise ValueError("Minima must have same length as parameter list")

        for min_, par in zip(min_array, self):
            par.min = min_

    @property
    def max(self):
        """Parameter maxima (`numpy.ndarray`)."""
        return np.array([_.max for _ in self._parameters], dtype=np.float64)

    @max.setter
    def max(self, max_array):
        """Parameter maxima (`numpy.ndarray`)."""
        if not len(self) == len(max_array):
            raise ValueError("Maxima must have same length as parameter list")

        for max_, par in zip(max_array, self):
            par.max = max_

    @property
    def value(self):
        """Parameter values (`numpy.ndarray`)."""
        return np.array([_.value for _ in self._parameters], dtype=np.float64)

    @value.setter
    def value(self, values):
        """Parameter values (`numpy.ndarray`)."""
        if not len(self) == len(values):
            raise ValueError("Values must have same length as parameter list")

        for value, par in zip(values, self):
            par.value = value

    @classmethod
    def from_stack(cls, parameters_list):
        """Create `Parameters` by stacking a list of other `Parameters` objects.

        Parameters
        ----------
        parameters_list : list of `Parameters`
            List of `Parameters` objects
        """
        return cls(itertools.chain.from_iterable(parameters_list))

    def copy(self):
        """A deep copy"""
        return copy.deepcopy(self)

    @property
    def free_parameters(self):
        """List of free parameters"""
        return self.__class__([par for par in self._parameters if not par.frozen])

    @property
    def unique_parameters(self):
        """Unique parameters (`Parameters`), first occurrence order kept."""
        return self.__class__(dict.fromkeys(self._parameters))

    @property
    def names(self):
        """List of parameter names"""
        return [par.name for par in self._parameters]

    def index(self, val):
        """Get position index for a given parameter.

        The input can be a parameter object, parameter name (str)
        or if a parameter index (int, including numpy integers) is passed in,
        it is simply returned.

        Raises
        ------
        IndexError
            If no parameter has the given name.
        TypeError
            If ``val`` has an unsupported type.
        """
        if isinstance(val, (int, np.integer)):
            return val
        elif isinstance(val, str):
            for idx, par in enumerate(self._parameters):
                if val == par.name:
                    return idx
            raise IndexError(f"No parameter: {val!r}")
        elif isinstance(val, Parameter):
            return self._parameters.index(val)
        else:
            raise TypeError(f"Invalid type: {type(val)!r}")

    def __getitem__(self, key):
        """Access parameter by name, index, slice or boolean mask.

        A slice or boolean mask returns a new `Parameters` object; any other
        key returns a single `Parameter` (see `index`).
        """
        if isinstance(key, np.ndarray) and key.dtype == bool:
            selected = np.array(self._parameters, dtype=object)[key]
            return self.__class__(list(selected))
        if isinstance(key, slice):
            return self.__class__(self._parameters[key])
        return self._parameters[self.index(key)]

    def __len__(self):
        return len(self._parameters)

    def __add__(self, other):
        if isinstance(other, Parameters):
            return Parameters.from_stack([self, other])
        else:
            raise TypeError(f"Invalid type: {other!r}")

    def to_dict(self):
        """Convert to a list of parameter dicts (see `Parameter.to_dict`)."""
        return [par.to_dict() for par in self._parameters]

    def to_table(self):
        """Convert parameter attributes to `~astropy.table.Table`."""
        rows = []
        for par in self._parameters:
            data = par.to_dict()
            data.setdefault("link", "")
            # These entries are not included as table columns.
            for key in ("scale_method", "interp"):
                data.pop(key, None)
            rows.append({"type": par.type, **data})
        table = table_from_row_data(rows)

        table["value"].format = ".4e"
        for name in ["error", "min", "max"]:
            table[name].format = ".3e"

        return table

    def __eq__(self, other):
        """Equal if both have the same length and contain identical objects."""
        try:
            same_length = len(self) == len(other)
        except TypeError:
            # ``other`` is not a sized container (e.g. a number or None).
            return NotImplemented
        return same_length and all(p is p_new for p, p_new in zip(self, other))

    @classmethod
    def from_dict(cls, data):
        """Create from a list of parameter dicts (see `Parameter.to_dict`).

        The input dicts are not modified.
        """
        parameters = []

        for par in data:
            # Work on a copy so the caller's dict keeps its "link" entry.
            kwargs = dict(par)
            link_label = kwargs.pop("link", None)
            parameter = Parameter(**kwargs)
            parameter._link_label_io = link_label
            parameters.append(parameter)

        return cls(parameters=parameters)

    def set_parameter_factors(self, factors):
        """Set factor of all parameters.

        Used in the optimizer interface. ``factors`` are assigned in order
        to the free (non-frozen) parameters only.
        """
        idx = 0
        for parameter in self._parameters:
            if not parameter.frozen:
                parameter.factor = factors[idx]
                idx += 1

    def autoscale(self):
        """Autoscale all parameters.

        See :func:`~gammapy.modeling.Parameter.autoscale`

        """
        for par in self._parameters:
            par.autoscale()

    def select(
        self,
        name=None,
        type=None,
        frozen=None,
    ):
        """Select parameters for which all given conditions are true.

        Parameters
        ----------
        name : str or list, tuple or set of str
            Name of the parameter
        type : {None, spatial, spectral, temporal}
           type of models
        frozen : bool
            Select frozen parameters if True, exclude them if False.

        Returns
        -------
        parameters : `Parameters`
           Selected parameters
        """
        selection = np.ones(len(self), dtype=bool)

        if name and not isinstance(name, (list, tuple, set)):
            name = [name]

        for idx, par in enumerate(self):
            if name:
                selection[idx] &= par.name in name

            if type:
                selection[idx] &= type == par.type

            if frozen is not None:
                # Compare as booleans; ``~`` on a Python bool is bitwise (-1/-2).
                selection[idx] &= bool(par.frozen) == bool(frozen)

        return self[selection]

    def freeze_all(self):
        """Freeze all parameters"""
        for par in self._parameters:
            par.frozen = True

    def unfreeze_all(self):
        """Unfreeze all parameters (even those frozen by default)"""
        for par in self._parameters:
            par.frozen = False

    def restore_status(self, restore_values=True):
        """Context manager to restore status.

        A copy of the values is made on enter,
        and those values are restored on exit.

        Parameters
        ----------
        restore_values : bool
            Restore values if True, otherwise restore only frozen status.

        Examples
        --------
        ::

            from gammapy.modeling.models import PowerLawSpectralModel
            pwl = PowerLawSpectralModel(index=2)
            with pwl.parameters.restore_status():
                pwl.parameters["index"].value = 3
            print(pwl.parameters["index"].value)
        """
        return restore_parameters_status(self, restore_values)
736
737
738
class restore_parameters_status:
    """Context manager that restores parameter values and frozen status on exit.

    The current values and frozen flags are recorded when the object is
    created and written back when the ``with`` block exits (also on error).

    Parameters
    ----------
    parameters : iterable of `Parameter`
        Parameters whose status is restored. It is iterated at construction
        and again on exit, so it should be a re-iterable sequence.
    restore_values : bool
        Restore values if True, otherwise restore only the frozen status.
    """

    def __init__(self, parameters, restore_values=True):
        self.restore_values = restore_values
        self._parameters = parameters
        self.values = [_.value for _ in parameters]
        self.frozen = [_.frozen for _ in parameters]

    def __enter__(self):
        return self

    def __exit__(self, type, value, traceback):
        # Returns None, so exceptions raised in the ``with`` block propagate.
        for par, old_value, old_frozen in zip(self._parameters, self.values, self.frozen):
            if self.restore_values:
                par.value = old_value
            par.frozen = old_frozen
753