1
|
|
|
# Licensed under a 3-clause BSD style license - see LICENSE.rst |
2
|
|
|
"""Model parameter classes.""" |
3
|
|
|
import collections.abc |
4
|
|
|
import copy |
5
|
|
|
import itertools |
6
|
|
|
import logging |
7
|
|
|
import numpy as np |
8
|
|
|
from astropy import units as u |
9
|
|
|
from gammapy.utils.interpolation import interpolation_scale |
10
|
|
|
from gammapy.utils.table import table_from_row_data |
11
|
|
|
|
12
|
|
|
# Public names exported by ``from ... import *``.
__all__ = [
    "Parameter",
    "Parameters",
]

# Module-level logger (used by ``Parameter.check_limits`` for bound warnings).
log = logging.getLogger(__name__)
15
|
|
|
|
16
|
|
|
|
17
|
|
|
def _get_parameters_str(parameters): |
18
|
|
|
str_ = "" |
19
|
|
|
|
20
|
|
|
for par in parameters: |
21
|
|
|
if par.name == "amplitude": |
22
|
|
|
value_format, error_format = "{:10.2e}", "{:7.1e}" |
23
|
|
|
else: |
24
|
|
|
value_format, error_format = "{:10.3f}", "{:7.2f}" |
25
|
|
|
|
26
|
|
|
line = "\t{:21} {:8}: " + value_format + "\t {} {:<12s}\n" |
27
|
|
|
|
28
|
|
|
if par._link_label_io is not None: |
29
|
|
|
name = par._link_label_io |
30
|
|
|
else: |
31
|
|
|
name = par.name |
32
|
|
|
|
33
|
|
|
if par.frozen: |
34
|
|
|
frozen, error = "(frozen)", "\t\t" |
35
|
|
|
else: |
36
|
|
|
frozen = "" |
37
|
|
|
try: |
38
|
|
|
error = "+/- " + error_format.format(par.error) |
39
|
|
|
except AttributeError: |
40
|
|
|
error = "" |
41
|
|
|
str_ += line.format(name, frozen, par.value, error, par.unit) |
42
|
|
|
return str_.expandtabs(tabsize=2) |
43
|
|
|
|
44
|
|
|
|
45
|
|
|
class Parameter:
    """A model parameter.

    Note that the parameter value has been split into
    a factor and scale like this::

        value = factor x scale

    Users should interact with the ``value``, ``quantity``
    or ``min`` and ``max`` properties and consider the fact
    that there is a ``factor`` and ``scale`` an implementation detail.

    That was introduced for numerical stability in parameter and error
    estimation methods, only in the Gammapy optimiser interface do we
    interact with the ``factor``, ``factor_min`` and ``factor_max`` properties,
    i.e. the optimiser "sees" the well-scaled problem.

    Parameters
    ----------
    name : str
        Name
    value : float or `~astropy.units.Quantity`
        Value. A plain number is stored as the ``factor`` (so the resulting
        ``value`` is ``value * scale``); a `~astropy.units.Quantity` or string
        is stored as the actual value and its unit takes precedence over
        ``unit``.
    scale : float, optional
        Scale (sometimes used in fitting)
    unit : `~astropy.units.Unit` or str, optional
        Unit
    min : float, optional
        Minimum (sometimes used in fitting)
    max : float, optional
        Maximum (sometimes used in fitting)
    frozen : bool, optional
        Frozen? (used in fitting)
    error : float
        Parameter error
    scan_min : float
        Minimum value for the parameter scan. Overwrites scan_n_sigma.
    scan_max : float
        Maximum value for the parameter scan. Overwrites scan_n_sigma.
    scan_n_values : int
        Number of values to be used for the parameter scan.
    scan_n_sigma : int
        Number of sigmas to scan.
    scan_values : `numpy.ndarray`
        Scan values. Overwrites all of the scan keywords before.
    scale_method : {'scale10', 'factor1', None}
        Method used to set ``factor`` and ``scale``
    interp : {"lin", "sqrt", "log"}
        Parameter scaling to use for the scan.
    is_norm : bool
        Whether the parameter represents the flux norm of the model.
    """

    def __init__(
        self,
        name,
        value,
        unit="",
        scale=1,
        min=np.nan,
        max=np.nan,
        frozen=False,
        error=0,
        scan_min=None,
        scan_max=None,
        scan_n_values=11,
        scan_n_sigma=2,
        scan_values=None,
        scale_method="scale10",
        interp="lin",
        is_norm=False,
    ):
        if not isinstance(name, str):
            raise TypeError(f"Name must be string, got '{type(name)}' instead")

        self._name = name
        self._link_label_io = None
        # ``scale`` must be set before the value: the ``value`` setter divides by it
        self.scale = scale
        self.min = min
        self.max = max
        self.frozen = frozen
        self._error = error
        self._is_norm = is_norm
        self._type = None

        # TODO: move this to a setter method that can be called from `__set__` also!
        # Having it here is bad: behaviour not clear if Quantity and `unit` is passed.
        if isinstance(value, u.Quantity) or isinstance(value, str):
            # the number is the actual value; the input's unit overrides ``unit``
            val = u.Quantity(value)
            self.value = val.value
            self.unit = val.unit
        else:
            # plain numbers are stored as the factor (value = value * scale)
            self.factor = value
            self.unit = unit

        self.scan_min = scan_min
        self.scan_max = scan_max
        self.scan_values = scan_values
        self.scan_n_values = scan_n_values
        self.scan_n_sigma = scan_n_sigma
        self.interp = interp
        self.scale_method = scale_method

    def __get__(self, instance, owner):
        # Accessed on the class: return the descriptor itself. Accessed on an
        # instance: return the instance's own parameter from its ``__dict__``
        # and record the owner's ``type`` attribute (None if absent).
        if instance is None:
            return self

        par = instance.__dict__[self.name]
        par._type = getattr(instance, "type", None)
        return par

    def __set__(self, instance, value):
        # Only `Parameter` objects may replace an instance's parameter.
        if isinstance(value, Parameter):
            instance.__dict__[self.name] = value
        else:
            par = instance.__dict__[self.name]
            raise TypeError(f"Cannot assign {value!r} to parameter {par!r}")

    def __set_name__(self, owner, name):
        # The attribute name must equal the parameter name, because
        # ``__get__``/``__set__`` use ``self.name`` as the ``__dict__`` key.
        if not self._name == name:
            raise ValueError(f"Expected parameter name '{name}', got {self._name}")

    @property
    def is_norm(self):
        """Whether the parameter represents the norm of the model"""
        return self._is_norm

    @property
    def type(self):
        """Type of the object the parameter was last accessed through.

        Set by ``__get__`` from the owner's ``type`` attribute; None otherwise.
        """
        return self._type

    @property
    def error(self):
        """Parameter error (float, in ``unit``)."""
        return self._error

    @error.setter
    def error(self, value):
        # a Quantity is converted to ``unit``; a plain number is taken as-is
        self._error = float(u.Quantity(value, unit=self.unit).value)

    @property
    def name(self):
        """Name (str)."""
        return self._name

    @property
    def factor(self):
        """Factor (float)."""
        return self._factor

    @factor.setter
    def factor(self, val):
        self._factor = float(val)

    @property
    def scale(self):
        """Scale (float)."""
        return self._scale

    @scale.setter
    def scale(self, val):
        self._scale = float(val)

    @property
    def unit(self):
        """Unit (`~astropy.units.Unit`)."""
        return self._unit

    @unit.setter
    def unit(self, val):
        self._unit = u.Unit(val)

    @property
    def min(self):
        """Minimum (float)."""
        return self._min

    @min.setter
    def min(self, val):
        "Astropy Table has masked values for NaN. Replacing with np.nan."
        if isinstance(val, np.ma.core.MaskedConstant):
            self._min = np.nan
        else:
            self._min = float(val)

    @property
    def factor_min(self):
        """Factor min (float).

        This ``factor_min = min / scale`` is for the optimizer interface.
        """
        return self.min / self.scale

    @property
    def max(self):
        """Maximum (float)."""
        return self._max

    @max.setter
    def max(self, val):
        "Astropy Table has masked values for NaN. Replacing with np.nan."
        if isinstance(val, np.ma.core.MaskedConstant):
            self._max = np.nan
        else:
            self._max = float(val)

    @property
    def factor_max(self):
        """Factor max (float).

        This ``factor_max = max / scale`` is for the optimizer interface.
        """
        return self.max / self.scale

    @property
    def scale_method(self):
        """Method used to set ``factor`` and ``scale``"""
        return self._scale_method

    @scale_method.setter
    def scale_method(self, val):
        if val not in ["scale10", "factor1"] and val is not None:
            raise ValueError(f"Invalid method: {val}")
        self._scale_method = val

    @property
    def frozen(self):
        """Frozen? (used in fitting) (bool)."""
        return self._frozen

    @frozen.setter
    def frozen(self, val):
        # The strings "True"/"False" are accepted and converted. Compare
        # explicitly: ``bool("False")`` would be True.
        if isinstance(val, str) and val in ("True", "False"):
            val = val == "True"
        if not isinstance(val, (bool, np.bool_)):
            raise TypeError(f"Invalid type: {val}, {type(val)}")
        self._frozen = val

    @property
    def value(self):
        """Value = factor x scale (float)."""
        return self._factor * self._scale

    @value.setter
    def value(self, val):
        self.factor = float(val) / self._scale

    @property
    def quantity(self):
        """Value times unit (`~astropy.units.Quantity`)."""
        return self.value * self.unit

    @quantity.setter
    def quantity(self, val):
        val = u.Quantity(val)

        if not val.unit.is_equivalent(self.unit):
            raise u.UnitConversionError(
                f"Unit must be equivalent to {self.unit} for parameter {self.name}"
            )

        if val.unit != self.unit:
            # min, max, error and explicitly set scan settings are plain floats
            # expressed in ``self.unit``. The parameter adopts ``val.unit``
            # below, so convert them too, to keep their physical meaning.
            conversion = self.unit.to(val.unit)
            self.min = self.min * conversion
            self.max = self.max * conversion
            self._error = self._error * conversion
            if self._scan_min is not None:
                self._scan_min = self._scan_min * conversion
            if self._scan_max is not None:
                self._scan_max = self._scan_max * conversion
            if self._scan_values is not None:
                self._scan_values = np.asarray(self._scan_values) * conversion

        self.value = val.value
        self.unit = val.unit

    # TODO: possibly allow to set this independently
    @property
    def conf_min(self):
        """Confidence min value (`float`)

        Returns parameter minimum if defined else the scan_min
        """
        if not np.isnan(self.min):
            return self.min
        else:
            return self.scan_min

    # TODO: possibly allow to set this independently
    @property
    def conf_max(self):
        """Confidence max value (`float`)

        Returns parameter maximum if defined else the scan_max
        """
        if not np.isnan(self.max):
            return self.max
        else:
            return self.scan_max

    @property
    def scan_min(self):
        """Stat scan min (defaults to ``value - scan_n_sigma * error``)"""
        if self._scan_min is None:
            return self.value - self.error * self.scan_n_sigma

        return self._scan_min

    @scan_min.setter
    def scan_min(self, value):
        """Stat scan min setter"""
        self._scan_min = value

    @property
    def scan_max(self):
        """Stat scan max (defaults to ``value + scan_n_sigma * error``)"""
        if self._scan_max is None:
            return self.value + self.error * self.scan_n_sigma

        return self._scan_max

    @scan_max.setter
    def scan_max(self, value):
        """Stat scan max setter"""
        self._scan_max = value

    @property
    def scan_n_sigma(self):
        """Stat scan n sigma"""
        return self._scan_n_sigma

    @scan_n_sigma.setter
    def scan_n_sigma(self, n_sigma):
        """Stat scan n sigma"""
        self._scan_n_sigma = int(n_sigma)

    @property
    def scan_values(self):
        """Stat scan values (`~numpy.ndarray`)

        Unless set explicitly, ``scan_n_values`` values between ``scan_min``
        and ``scan_max``, spaced evenly in the ``interp`` scale.
        """
        if self._scan_values is None:
            scale = interpolation_scale(self.interp)
            parmin, parmax = scale([self.scan_min, self.scan_max])
            values = np.linspace(parmin, parmax, self.scan_n_values)
            return scale.inverse(values)

        return self._scan_values

    @scan_values.setter
    def scan_values(self, values):
        """Set scan values"""
        self._scan_values = values

    def check_limits(self):
        """Emit a warning if the value is outside the min/max range.

        Frozen parameters are not checked, NaN bounds count as unbounded and
        a value equal to a bound is reported as well.
        """
        if self.frozen:
            return

        below = not np.isnan(self.min) and self.value <= self.min
        above = not np.isnan(self.max) and self.value >= self.max
        if below or above:
            log.warning(
                f"Value {self.value} is outside bounds [{self.min}, {self.max}]"
                f" for parameter '{self.name}'"
            )

    def __repr__(self):
        return (
            f"{self.__class__.__name__}(name={self.name!r}, value={self.value!r}, "
            f"factor={self.factor!r}, scale={self.scale!r}, unit={self.unit!r}, "
            f"min={self.min!r}, max={self.max!r}, frozen={self.frozen!r}, id={hex(id(self))})"
        )

    def copy(self):
        """A deep copy"""
        return copy.deepcopy(self)

    def update_from_dict(self, data):
        """Update parameters from a dict.

        Only ``value``, ``unit``, ``min``, ``max`` and ``frozen`` are set (in
        that order); all five keys must be present in ``data``. The name and
        other attributes are left untouched.
        """
        keys = ["value", "unit", "min", "max", "frozen"]
        for k in keys:
            setattr(self, k, data[k])

    def to_dict(self):
        """Convert to dict.

        The ``link`` key is only present if a link label is set.
        """
        output = {
            "name": self.name,
            "value": self.value,
            "unit": self.unit.to_string("fits"),
            "error": self.error,
            "min": self.min,
            "max": self.max,
            "frozen": self.frozen,
            "interp": self.interp,
            "scale_method": self.scale_method,
            "is_norm": self.is_norm,
        }

        if self._link_label_io is not None:
            output["link"] = self._link_label_io

        return output

    def autoscale(self):
        """Autoscale the parameters.

        Set ``factor`` and ``scale`` according to ``scale_method`` attribute

        Available ``scale_method``

        * ``scale10`` sets ``scale`` to power of 10,
          so that abs(factor) is in the range 1 to 10
        * ``factor1`` sets ``factor, scale = sign(value), abs(value)``

        In both cases the sign of value is stored in ``factor``,
        i.e. the ``scale`` is always positive.
        If ``scale_method`` is None the scaling is ignored.
        Parameters with a zero or non-finite value are left unchanged.
        """
        value = self.value

        # A zero/NaN/inf value has no order of magnitude; deriving a scale
        # from it would give a zero or non-finite scale, after which
        # ``value`` could no longer be set (the setter divides by the scale).
        if value == 0 or not np.isfinite(value):
            return

        if self.scale_method == "scale10":
            exponent = np.floor(np.log10(np.abs(value)))
            scale = np.power(10.0, exponent)
            self.factor = value / scale
            self.scale = scale

        elif self.scale_method == "factor1":
            # keep the sign in ``factor`` so the scale stays positive and
            # ``factor_min <= factor_max`` whenever ``min <= max``
            self.factor, self.scale = np.sign(value), np.abs(value)
459
|
|
|
|
460
|
|
|
|
461
|
|
|
class Parameters(collections.abc.Sequence):
    """Parameters container.

    A sequence of `Parameter` objects that can be indexed by position
    (int or slice), parameter name, `Parameter` object or boolean mask.

    Parameters
    ----------
    parameters : list of `Parameter`
        List of parameters
    """

    def __init__(self, parameters=None):
        if parameters is None:
            parameters = []
        else:
            parameters = list(parameters)

        self._parameters = parameters

    def check_limits(self):
        """Check parameter limits and emit a warning"""
        for par in self:
            par.check_limits()

    @property
    def types(self):
        """Parameter types"""
        return [par.type for par in self]

    @property
    def min(self):
        """Parameter mins (`numpy.ndarray`)."""
        return np.array([_.min for _ in self._parameters], dtype=np.float64)

    @min.setter
    def min(self, min_array):
        """Parameter minima (`numpy.ndarray`)."""
        if not len(self) == len(min_array):
            raise ValueError("Minima must have same length as parameter list")

        for min_, par in zip(min_array, self):
            par.min = min_

    @property
    def max(self):
        """Parameter maxima (`numpy.ndarray`)."""
        return np.array([_.max for _ in self._parameters], dtype=np.float64)

    @max.setter
    def max(self, max_array):
        """Parameter maxima (`numpy.ndarray`)."""
        if not len(self) == len(max_array):
            raise ValueError("Maxima must have same length as parameter list")

        for max_, par in zip(max_array, self):
            par.max = max_

    @property
    def value(self):
        """Parameter values (`numpy.ndarray`)."""
        return np.array([_.value for _ in self._parameters], dtype=np.float64)

    @value.setter
    def value(self, values):
        """Parameter values (`numpy.ndarray`)."""
        if not len(self) == len(values):
            raise ValueError("Values must have same length as parameter list")

        for value, par in zip(values, self):
            par.value = value

    @classmethod
    def from_stack(cls, parameters_list):
        """Create `Parameters` by stacking a list of other `Parameters` objects.

        Parameters
        ----------
        parameters_list : list of `Parameters`
            List of `Parameters` objects
        """
        pars = itertools.chain.from_iterable(parameters_list)
        return cls(pars)

    def copy(self):
        """A deep copy"""
        return copy.deepcopy(self)

    @property
    def free_parameters(self):
        """List of free parameters"""
        return self.__class__([par for par in self._parameters if not par.frozen])

    @property
    def unique_parameters(self):
        """Unique parameters (`Parameters`), in order of first occurrence."""
        return self.__class__(dict.fromkeys(self._parameters))

    @property
    def names(self):
        """List of parameter names"""
        return [par.name for par in self._parameters]

    def index(self, val):
        """Get position index for a given parameter.

        The input can be a parameter object, parameter name (str)
        or if a parameter index (int or NumPy integer) is passed in, it is
        simply returned.

        Raises
        ------
        IndexError
            If no parameter has the given name.
        ValueError
            If the given `Parameter` object is not in the container.
        TypeError
            For any other input type.
        """
        if isinstance(val, np.integer):
            val = int(val)

        if isinstance(val, int):
            return val
        elif isinstance(val, Parameter):
            return self._parameters.index(val)
        elif isinstance(val, str):
            for idx, par in enumerate(self._parameters):
                if val == par.name:
                    return idx
            raise IndexError(f"No parameter: {val!r}")
        else:
            raise TypeError(f"Invalid type: {type(val)!r}")

    def __getitem__(self, key):
        """Access parameter by name, index, slice or boolean mask.

        A boolean mask or a slice returns a new `Parameters` object; any
        other key returns a single `Parameter`.
        """
        if isinstance(key, np.ndarray) and key.dtype == bool:
            return self.__class__(list(np.array(self._parameters, dtype=object)[key]))
        if isinstance(key, slice):
            return self.__class__(self._parameters[key])

        idx = self.index(key)
        return self._parameters[idx]

    def __len__(self):
        return len(self._parameters)

    def __add__(self, other):
        if isinstance(other, Parameters):
            return Parameters.from_stack([self, other])
        else:
            raise TypeError(f"Invalid type: {other!r}")

    def to_dict(self):
        """Convert to a list of parameter dicts (see `Parameter.to_dict`)."""
        return [par.to_dict() for par in self._parameters]

    def to_table(self):
        """Convert parameter attributes to `~astropy.table.Table`."""
        rows = []
        for p in self._parameters:
            d = p.to_dict()
            if "link" not in d:
                d["link"] = ""
            for key in ["scale_method", "interp"]:
                if key in d:
                    del d[key]
            rows.append({**dict(type=p.type), **d})
        table = table_from_row_data(rows)

        table["value"].format = ".4e"
        for name in ["error", "min", "max"]:
            table[name].format = ".3e"

        return table

    def __eq__(self, other):
        # equal if both hold the very same parameter objects, in the same order
        if len(self) != len(other):
            return False
        return all(p is p_new for p, p_new in zip(self, other))

    @classmethod
    def from_dict(cls, data):
        """Create `Parameters` from a list of parameter dicts.

        An optional ``link`` entry is stored as the parameter's link label.
        The input dicts are not modified.
        """
        parameters = []

        for par in data:
            # copy, so popping "link" does not mutate the caller's data
            par = dict(par)
            link_label = par.pop("link", None)
            parameter = Parameter(**par)
            parameter._link_label_io = link_label
            parameters.append(parameter)

        return cls(parameters=parameters)

    def set_parameter_factors(self, factors):
        """Set factor of all free parameters.

        Used in the optimizer interface. ``factors`` are assigned in order to
        the parameters that are not frozen; frozen parameters are skipped.
        """
        idx = 0
        for parameter in self._parameters:
            if not parameter.frozen:
                parameter.factor = factors[idx]
                idx += 1

    def autoscale(self):
        """Autoscale all parameters.

        See :func:`~gammapy.modeling.Parameter.autoscale`

        """
        for par in self._parameters:
            par.autoscale()

    def select(
        self,
        name=None,
        type=None,
        frozen=None,
    ):
        """Select the parameters that fulfil all given conditions.

        Parameters
        ----------
        name : str or list
            Name of the parameter (or a list/tuple of names)
        type : {None, spatial, spectral, temporal}
            type of models
        frozen : bool
            Select frozen parameters if True, exclude them if False.

        Returns
        -------
        parameters : `Parameters`
            Selected parameters
        """
        selection = np.ones(len(self), dtype=bool)

        if name and isinstance(name, str):
            name = [name]

        for idx, par in enumerate(self):
            if name:
                selection[idx] &= par.name in name

            if type:
                selection[idx] &= type == par.type

            if frozen is not None:
                selection[idx] &= bool(par.frozen) if frozen else not par.frozen

        return self[selection]

    def freeze_all(self):
        """Freeze all parameters"""
        for par in self._parameters:
            par.frozen = True

    def unfreeze_all(self):
        """Unfreeze all parameters (even those frozen by default)"""
        for par in self._parameters:
            par.frozen = False

    def restore_status(self, restore_values=True):
        """Context manager to restore status.

        A copy of the values (and frozen flags) is made when this method is
        called, and those are restored on exit.

        Parameters
        ----------
        restore_values : bool
            Restore values if True, otherwise restore only frozen status.

        Examples
        --------
        ::

            from gammapy.modeling.models import PowerLawSpectralModel
            pwl = PowerLawSpectralModel(index=2)
            with pwl.parameters.restore_status():
                pwl.parameters["index"].value = 3
            print(pwl.parameters["index"].value)
        """
        return restore_parameters_status(self, restore_values)
736
|
|
|
|
737
|
|
|
|
738
|
|
|
class restore_parameters_status:
    """Context manager restoring parameter values and frozen status on exit.

    The current values and frozen flags are recorded when the object is
    created. Leaving the ``with`` block writes the frozen flags back, and
    the values too if ``restore_values`` is True. Exceptions raised inside
    the block are not suppressed.

    Parameters
    ----------
    parameters : `Parameters` or list of `Parameter`
        Parameters whose status is saved and restored. It is iterated
        more than once, so it must not be a one-shot iterator.
    restore_values : bool
        Restore values if True, otherwise restore only frozen status.
    """

    def __init__(self, parameters, restore_values=True):
        self.restore_values = restore_values
        self._parameters = parameters
        self.values = [_.value for _ in parameters]
        self.frozen = [_.frozen for _ in parameters]

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        for par, saved_value, saved_frozen in zip(
            self._parameters, self.values, self.frozen
        ):
            if self.restore_values:
                par.value = saved_value
            par.frozen = saved_frozen
753
|
|
|
|