1 | # Licensed under a 3-clause BSD style license - see LICENSE.rst |
||
2 | import logging |
||
3 | import collections.abc |
||
4 | import copy |
||
5 | from os.path import split |
||
6 | import yaml |
||
7 | import numpy as np |
||
8 | import astropy.units as u |
||
9 | from astropy.table import Table |
||
10 | from astropy.coordinates import SkyCoord |
||
11 | from regions import PointSkyRegion |
||
12 | from gammapy.modeling import Covariance, Parameter, Parameters |
||
13 | from gammapy.utils.scripts import make_name, make_path |
||
14 | from gammapy.maps import RegionGeom, Map |
||
15 | |||
16 | |||
# Module-level logger used for warnings/errors emitted by this module
log = logging.getLogger(__name__)
||
18 | |||
19 | |||
def _set_link(shared_register, model):
    """Restore parameter links of ``model`` from their I/O link labels.

    For every parameter of ``model`` that carries a ``_link_label_io``:
    the first parameter seen with a given label is recorded in
    ``shared_register``; any later parameter with the same label is replaced
    on ``model`` (``setattr`` under the parameter's name) by the recorded one.

    Parameters
    ----------
    shared_register : dict
        Mapping of link label to the parameter instance shared under it.
        Updated in place.
    model : object
        Model whose ``parameters`` are inspected and possibly replaced.

    Returns
    -------
    shared_register : dict
        The same (updated) register.
    """
    for parameter in model.parameters:
        label = parameter._link_label_io
        if label is None:
            continue
        if label not in shared_register:
            shared_register[label] = parameter
        else:
            setattr(model, parameter.name, shared_register[label])
    return shared_register
||
31 | |||
def _get_model_class_from_dict(data):
    """Get a model class from a dict.

    The class is looked up from ``data["type"]`` in the general model
    registry. Otherwise the first of the keys "spatial", "spectral",
    "temporal" found in ``data`` selects the registry, and
    ``data[key]["type"]`` is looked up in it.

    Parameters
    ----------
    data : dict
        Model specification.

    Returns
    -------
    cls : type
        Model class.

    Raises
    ------
    ValueError
        If ``data`` contains none of the keys "type", "spatial", "spectral"
        or "temporal".
    """
    from . import (
        MODEL_REGISTRY,
        SPATIAL_MODEL_REGISTRY,
        SPECTRAL_MODEL_REGISTRY,
        TEMPORAL_MODEL_REGISTRY,
    )

    if "type" in data:
        return MODEL_REGISTRY.get_cls(data["type"])

    # Same priority order as before: spatial, then spectral, then temporal
    for key, registry in (
        ("spatial", SPATIAL_MODEL_REGISTRY),
        ("spectral", SPECTRAL_MODEL_REGISTRY),
        ("temporal", TEMPORAL_MODEL_REGISTRY),
    ):
        if key in data:
            return registry.get_cls(data[key]["type"])

    # Previously this fell through to ``return cls`` with ``cls`` unbound
    raise ValueError(
        "Cannot determine model class: expected one of the keys 'type', "
        f"'spatial', 'spectral' or 'temporal', got {list(data)}"
    )
||
0 ignored issues
–
show
introduced
by
Loading history...
|
|||
50 | |||
51 | |||
# Public names of this module (what ``from ... import *`` exports)
__all__ = ["Model", "Models", "DatasetModels"]
||
53 | |||
54 | |||
class ModelBase:
    """Model base class.

    Sub-classes declare their parameters as `~gammapy.modeling.Parameter`
    class attributes; ``__init_subclass__`` collects them into
    ``default_parameters`` and ``__init__`` copies them onto each instance.
    """

    _type = None

    def __init__(self, **kwargs):
        # Copy default parameters from the class to the instance
        default_parameters = self.default_parameters.copy()

        for par in default_parameters:
            value = kwargs.get(par.name, par)

            if not isinstance(value, Parameter):
                # A plain value only updates the quantity of the copied default
                par.quantity = u.Quantity(value)
            else:
                # A Parameter instance is used as given (e.g. a linked one)
                par = value

            setattr(self, par.name, par)
        self._covariance = Covariance(self.parameters)

    def __getattribute__(self, name):
        # Parameters are stored as plain attributes; resolve them through
        # their ``__get__`` before handing them out.
        value = object.__getattribute__(self, name)

        if isinstance(value, Parameter):
            return value.__get__(self, None)

        return value

    @property
    def type(self):
        """Model type (class attribute ``_type``; None for the base class)."""
        return self._type

    def __init_subclass__(cls, **kwargs):
        # Add parameters list on the model sub-class (not instances)
        cls.default_parameters = Parameters(
            [_ for _ in cls.__dict__.values() if isinstance(_, Parameter)]
        )

    @classmethod
    def from_parameters(cls, parameters, **kwargs):
        """Create model from parameter list

        Parameters
        ----------
        parameters : `Parameters`
            Parameters for init

        Returns
        -------
        model : `Model`
            Model instance
        """
        for par in parameters:
            kwargs[par.name] = par
        return cls(**kwargs)

    @staticmethod
    def _is_default_value(value, default):
        """Whether a serialised parameter attribute equals its default.

        NaN is considered equal to NaN. Values ``np.isnan`` cannot handle
        (e.g. strings such as ``interp``) are compared with ``==`` only.
        """
        if value == default:
            return True
        try:
            return bool(np.isnan(value) and np.isnan(default))
        except TypeError:
            return False

    def _check_covariance(self):
        # Rebuild the covariance when the parameter set no longer matches
        if not self.parameters == self._covariance.parameters:
            self._covariance = Covariance(self.parameters)

    @property
    def covariance(self):
        """Covariance (`~gammapy.modeling.Covariance`) of the model parameters.

        The diagonal is refreshed from each parameter's ``error`` (squared);
        NaN errors are replaced by 1.
        """
        self._check_covariance()
        for par in self.parameters:
            pars = Parameters([par])
            error = np.nan_to_num(par.error ** 2, nan=1)
            covar = Covariance(pars, data=[[error]])
            self._covariance.set_subcovariance(covar)

        return self._covariance

    @covariance.setter
    def covariance(self, covariance):
        # Store the matrix data, then set each parameter's error to the
        # square root of its diagonal entry.
        self._check_covariance()
        self._covariance.data = covariance

        for par in self.parameters:
            pars = Parameters([par])
            variance = self._covariance.get_subcovariance(pars)
            par.error = np.sqrt(variance)

    @property
    def parameters(self):
        """Parameters (`~gammapy.modeling.Parameters`)"""
        return Parameters(
            [getattr(self, name) for name in self.default_parameters.names]
        )

    def copy(self):
        """A deep copy."""
        return copy.deepcopy(self)

    def to_dict(self, full_output=False):
        """Create dict for YAML serialisation

        Parameters
        ----------
        full_output : bool
            If False, the parameter attributes ``min``, ``max``, ``error``,
            ``interp``, ``scale_method`` and ``frozen`` are omitted when they
            equal the class default, and ``unit`` is omitted when the default
            unit is empty.

        Returns
        -------
        data : dict
            ``{"type": tag, "parameters": [...]}``, nested under
            ``self._type`` when that is set.
        """
        tag = self.tag[0] if isinstance(self.tag, list) else self.tag
        params = self.parameters.to_dict()

        if not full_output:
            for par, par_default in zip(params, self.default_parameters):
                init = par_default.to_dict()
                for item in ["min", "max", "error", "interp", "scale_method"]:
                    # Keep user-set values even when the default is NaN, and
                    # never call ``np.isnan`` on string attributes.
                    if self._is_default_value(par[item], init[item]):
                        del par[item]

                # Omit ``frozen`` only when it matches the default, so an
                # unfrozen default-frozen parameter survives a round trip.
                if par["frozen"] == init["frozen"]:
                    del par["frozen"]

                if init["unit"] == "":
                    del par["unit"]
        data = {"type": tag, "parameters": params}
        if self._type is None:
            return data
        else:
            return {self._type: data}

    @classmethod
    def from_dict(cls, data):
        """Create a model from a dict (as produced by `to_dict`).

        Parameters missing from ``data`` keep their class default and a
        warning is logged.

        Raises
        ------
        ValueError
            If ``data["type"]`` is not one of the class tags.
        """
        kwargs = {}

        key0 = next(iter(data))
        if key0 in ["spatial", "temporal", "spectral"]:
            data = data[key0]

        # A single string tag must be compared as a whole, not as a substring
        tags = [cls.tag] if isinstance(cls.tag, str) else cls.tag
        if data["type"] not in tags:
            raise ValueError(
                f"Invalid model type {data['type']} for Class {cls.__name__}"
            )

        input_names = [_["name"] for _ in data["parameters"]]

        par_data = []
        for par in cls.default_parameters:
            par_dict = par.to_dict()
            try:
                index = input_names.index(par_dict["name"])
            except ValueError:
                log.warning(
                    f"Parameter {par_dict['name']} not defined. Using default value: {par_dict['value']} {par_dict['unit']}"
                )
            else:
                par_dict.update(data["parameters"][index])
            par_data.append(par_dict)

        parameters = Parameters.from_dict(par_data)

        # TODO: this is a special case for spatial models, maybe better move to `SpatialModel` base class
        if "frame" in data:
            kwargs["frame"] = data["frame"]

        return cls.from_parameters(parameters, **kwargs)

    def __str__(self):
        string = f"{self.__class__.__name__}\n"
        if len(self.parameters) > 0:
            string += f"\n{self.parameters.to_table()}"
        return string

    @property
    def frozen(self):
        """Frozen status of a model, True if all parameters are frozen """
        return np.all([p.frozen for p in self.parameters])

    def freeze(self):
        """Freeze all parameters"""
        self.parameters.freeze_all()

    def unfreeze(self):
        """Restore parameters frozen status to default"""
        for p, default in zip(self.parameters, self.default_parameters):
            p.frozen = default.frozen

    def reassign(self, datasets_names, new_datasets_names):
        """Reassign a model from one dataset to another

        Parameters
        ----------
        datasets_names : str or list
            Name of the datasets where the model is currently defined
        new_datasets_names : str or list
            Name of the datasets where the model should be defined instead.
            If multiple names are given the two lists must have the same
            length, as the reassignment is element-wise.

        Returns
        -------
        model : `Model`
            Reassigned model.

        Raises
        ------
        ValueError
            If the two name lists have different lengths.
        """
        model = self.copy(name=self.name)

        if not isinstance(datasets_names, list):
            datasets_names = [datasets_names]

        if not isinstance(new_datasets_names, list):
            new_datasets_names = [new_datasets_names]

        if len(datasets_names) != len(new_datasets_names):
            raise ValueError(
                "datasets_names and new_datasets_names must have the same length"
            )

        # Models without a ``datasets_names`` attribute are returned unchanged
        current = getattr(model, "datasets_names", None)
        if isinstance(current, str):
            current = [current]
            model.datasets_names = current

        if current:
            # Exact, simultaneous mapping: no substring replacement and no
            # chaining of successive renames.
            mapping = dict(zip(datasets_names, new_datasets_names))
            model.datasets_names = [mapping.get(name, name) for name in current]

        return model
||
262 | |||
263 | |||
class Model:
    """Model class that only provides methods to create models listed in the registries."""

    @staticmethod
    def create(tag, model_type=None, *args, **kwargs):
        """Create a model instance.

        Parameters
        ----------
        tag : str
            Tag of the model class in the registry.
        model_type : {None, "spatial", "spectral", "temporal"}
            Registry to search; if None the general model registry is used.
        *args, **kwargs
            Passed on to the model class constructor.

        Examples
        --------
        >>> from gammapy.modeling.models import Model
        >>> spectral_model = Model.create("pl-2", model_type="spectral", amplitude="1e-10 cm-2 s-1", index=3)
        >>> type(spectral_model)
        <class 'gammapy.modeling.models.spectral.PowerLaw2SpectralModel'>
        """
        spec = {"type": tag} if model_type is None else {model_type: {"type": tag}}
        model_cls = _get_model_class_from_dict(spec)
        return model_cls(*args, **kwargs)

    @staticmethod
    def from_dict(data):
        """Create a model instance from a dict, via the matching class's ``from_dict``."""
        return _get_model_class_from_dict(data).from_dict(data)
||
292 | |||
293 | |||
294 | |||
295 | class DatasetModels(collections.abc.Sequence): |
||
296 | """Immutable models container |
||
297 | |||
298 | Parameters |
||
299 | ---------- |
||
300 | models : `SkyModel`, list of `SkyModel` or `Models` |
||
301 | Sky models |
||
302 | """ |
||
303 | |||
304 | def __init__(self, models=None): |
||
305 | if models is None: |
||
306 | models = [] |
||
307 | |||
308 | if isinstance(models, (Models, DatasetModels)): |
||
309 | models = models._models |
||
310 | elif isinstance(models, ModelBase): |
||
311 | models = [models] |
||
312 | elif not isinstance(models, list): |
||
313 | raise TypeError(f"Invalid type: {models!r}") |
||
314 | |||
315 | unique_names = [] |
||
316 | for model in models: |
||
317 | if model.name in unique_names: |
||
318 | raise (ValueError("Model names must be unique")) |
||
319 | unique_names.append(model.name) |
||
320 | |||
321 | self._models = models |
||
322 | self._covar_file = None |
||
323 | self._covariance = Covariance(self.parameters) |
||
324 | |||
325 | def _check_covariance(self): |
||
326 | if not self.parameters == self._covariance.parameters: |
||
327 | self._covariance = Covariance.from_stack( |
||
328 | [model.covariance for model in self._models] |
||
329 | ) |
||
330 | |||
331 | @property |
||
332 | def covariance(self): |
||
333 | self._check_covariance() |
||
334 | |||
335 | for model in self._models: |
||
336 | self._covariance.set_subcovariance(model.covariance) |
||
337 | |||
338 | return self._covariance |
||
339 | |||
340 | @covariance.setter |
||
341 | def covariance(self, covariance): |
||
342 | self._check_covariance() |
||
343 | self._covariance.data = covariance |
||
344 | |||
345 | for model in self._models: |
||
346 | subcovar = self._covariance.get_subcovariance(model.covariance.parameters) |
||
347 | model.covariance = subcovar |
||
348 | |||
349 | @property |
||
350 | def parameters(self): |
||
351 | return Parameters.from_stack([_.parameters for _ in self._models]) |
||
352 | |||
353 | @property |
||
354 | def parameters_unique_names(self): |
||
355 | """List of unique parameter names as model_name.par_type.par_name""" |
||
356 | names = [] |
||
357 | for model in self: |
||
358 | for par in model.parameters: |
||
359 | components = [model.name, par.type, par.name] |
||
360 | name = ".".join(components) |
||
361 | names.append(name) |
||
362 | |||
363 | return names |
||
364 | |||
365 | @property |
||
366 | def names(self): |
||
367 | return [m.name for m in self._models] |
||
368 | |||
369 | @classmethod |
||
370 | def read(cls, filename): |
||
371 | """Read from YAML file.""" |
||
372 | yaml_str = make_path(filename).read_text() |
||
373 | path, filename = split(filename) |
||
374 | return cls.from_yaml(yaml_str, path=path) |
||
375 | |||
376 | @classmethod |
||
377 | def from_yaml(cls, yaml_str, path=""): |
||
378 | """Create from YAML string.""" |
||
379 | data = yaml.safe_load(yaml_str) |
||
380 | return cls.from_dict(data, path=path) |
||
381 | |||
382 | @classmethod |
||
383 | def from_dict(cls, data, path=""): |
||
384 | """Create from dict.""" |
||
385 | from . import MODEL_REGISTRY, SkyModel |
||
386 | |||
387 | models = [] |
||
388 | |||
389 | for component in data["components"]: |
||
390 | model_cls = MODEL_REGISTRY.get_cls(component["type"]) |
||
391 | model = model_cls.from_dict(component) |
||
392 | models.append(model) |
||
393 | |||
394 | models = cls(models) |
||
395 | |||
396 | if "covariance" in data: |
||
397 | filename = data["covariance"] |
||
398 | path = make_path(path) |
||
399 | if not (path / filename).exists(): |
||
400 | path, filename = split(filename) |
||
401 | |||
402 | models.read_covariance(path, filename, format="ascii.fixed_width") |
||
403 | |||
404 | shared_register = {} |
||
405 | for model in models: |
||
406 | if isinstance(model, SkyModel): |
||
407 | submodels = [ |
||
408 | model.spectral_model, |
||
409 | model.spatial_model, |
||
410 | model.temporal_model, |
||
411 | ] |
||
412 | for submodel in submodels: |
||
413 | if submodel is not None: |
||
414 | shared_register = _set_link(shared_register, submodel) |
||
415 | else: |
||
416 | shared_register = _set_link(shared_register, model) |
||
417 | return models |
||
418 | |||
419 | def write( |
||
420 | self, |
||
421 | path, |
||
422 | overwrite=False, |
||
423 | full_output=False, |
||
424 | overwrite_templates=False, |
||
425 | write_covariance=True, |
||
426 | ): |
||
427 | """Write to YAML file. |
||
428 | |||
429 | Parameters |
||
430 | ---------- |
||
431 | path : `pathlib.Path` or str |
||
432 | path to write files |
||
433 | overwrite : bool |
||
434 | overwrite YAML files |
||
435 | overwrite_templates : bool |
||
436 | overwrite templates FITS files |
||
437 | write_covariance : bool |
||
438 | save covariance or not |
||
439 | """ |
||
440 | base_path, _ = split(path) |
||
441 | path = make_path(path) |
||
442 | base_path = make_path(base_path) |
||
443 | |||
444 | if path.exists() and not overwrite: |
||
445 | raise IOError(f"File exists already: {path}") |
||
446 | |||
447 | if ( |
||
448 | write_covariance |
||
449 | and self.covariance is not None |
||
450 | and len(self.parameters) != 0 |
||
451 | ): |
||
452 | filecovar = path.stem + "_covariance.dat" |
||
453 | kwargs = dict( |
||
454 | format="ascii.fixed_width", delimiter="|", overwrite=overwrite |
||
455 | ) |
||
456 | self.write_covariance(base_path / filecovar, **kwargs) |
||
457 | self._covar_file = filecovar |
||
458 | |||
459 | path.write_text(self.to_yaml(full_output, overwrite_templates)) |
||
460 | |||
461 | def to_yaml(self, full_output=False, overwrite_templates=False): |
||
462 | """Convert to YAML string.""" |
||
463 | data = self.to_dict(full_output, overwrite_templates) |
||
464 | return yaml.dump( |
||
465 | data, sort_keys=False, indent=4, width=80, default_flow_style=False |
||
466 | ) |
||
467 | |||
468 | def to_dict(self, full_output=False, overwrite_templates=False): |
||
469 | """Convert to dict.""" |
||
470 | # update linked parameters labels |
||
471 | params_list = [] |
||
472 | params_shared = [] |
||
473 | for param in self.parameters: |
||
474 | if param not in params_list: |
||
475 | params_list.append(param) |
||
476 | params_list.append(param) |
||
477 | elif param not in params_shared: |
||
478 | params_shared.append(param) |
||
479 | for param in params_shared: |
||
480 | param._link_label_io = param.name + "@" + make_name() |
||
481 | |||
482 | models_data = [] |
||
483 | for model in self._models: |
||
484 | model_data = model.to_dict(full_output) |
||
485 | models_data.append(model_data) |
||
486 | if ( |
||
487 | hasattr(model, "spatial_model") |
||
488 | and model.spatial_model is not None |
||
489 | and "template" in model.spatial_model.tag |
||
490 | ): |
||
491 | model.spatial_model.write(overwrite=overwrite_templates) |
||
492 | |||
493 | if self._covar_file is not None: |
||
494 | return { |
||
495 | "components": models_data, |
||
496 | "covariance": str(self._covar_file), |
||
497 | } |
||
498 | else: |
||
499 | return {"components": models_data} |
||
500 | |||
501 | def to_parameters_table(self): |
||
502 | """Convert Models parameters to an astropy Table.""" |
||
503 | table = self.parameters.to_table() |
||
504 | # Warning: splitting of parameters will break is source name has a "." in its name. |
||
505 | model_name = [name.split(".")[0] for name in self.parameters_unique_names] |
||
506 | table.add_column(model_name, name="model", index=0) |
||
507 | self._table_cached = table |
||
508 | return table |
||
509 | |||
510 | def update_parameters_from_table(self, t): |
||
511 | """Update Models from an astropy Table.""" |
||
512 | parameters_dict = [dict(zip(t.colnames, row)) for row in t] |
||
513 | for k, data in enumerate(parameters_dict): |
||
514 | self.parameters[k].update_from_dict(data) |
||
515 | |||
516 | def read_covariance(self, path, filename="_covariance.dat", **kwargs): |
||
517 | """Read covariance data from file |
||
518 | |||
519 | Parameters |
||
520 | ---------- |
||
521 | filename : str |
||
522 | Filename |
||
523 | **kwargs : dict |
||
524 | Keyword arguments passed to `~astropy.table.Table.read` |
||
525 | |||
526 | """ |
||
527 | path = make_path(path) |
||
528 | filepath = str(path / filename) |
||
529 | t = Table.read(filepath, **kwargs) |
||
530 | t.remove_column("Parameters") |
||
531 | arr = np.array(t) |
||
532 | data = arr.view(float).reshape(arr.shape + (-1,)) |
||
533 | self.covariance = data |
||
534 | self._covar_file = filename |
||
535 | |||
536 | def write_covariance(self, filename, **kwargs): |
||
537 | """Write covariance to file |
||
538 | |||
539 | Parameters |
||
540 | ---------- |
||
541 | filename : str |
||
542 | Filename |
||
543 | **kwargs : dict |
||
544 | Keyword arguments passed to `~astropy.table.Table.write` |
||
545 | |||
546 | """ |
||
547 | names = self.parameters_unique_names |
||
548 | table = Table() |
||
549 | table["Parameters"] = names |
||
550 | |||
551 | for idx, name in enumerate(names): |
||
552 | values = self.covariance.data[idx] |
||
553 | table[name] = values |
||
554 | |||
555 | table.write(make_path(filename), **kwargs) |
||
556 | |||
557 | def __str__(self): |
||
558 | str_ = f"{self.__class__.__name__}\n\n" |
||
559 | |||
560 | for idx, model in enumerate(self): |
||
561 | str_ += f"Component {idx}: " |
||
562 | str_ += str(model) |
||
563 | |||
564 | return str_.expandtabs(tabsize=2) |
||
565 | |||
566 | def __add__(self, other): |
||
567 | if isinstance(other, (Models, list)): |
||
568 | return Models([*self, *other]) |
||
569 | elif isinstance(other, ModelBase): |
||
570 | if other.name in self.names: |
||
571 | raise (ValueError("Model names must be unique")) |
||
572 | return Models([*self, other]) |
||
573 | else: |
||
574 | raise TypeError(f"Invalid type: {other!r}") |
||
575 | |||
576 | def __getitem__(self, key): |
||
577 | if isinstance(key, np.ndarray) and key.dtype == bool: |
||
578 | return self.__class__(list(np.array(self._models)[key])) |
||
579 | else: |
||
580 | return self._models[self.index(key)] |
||
581 | |||
582 | def index(self, key): |
||
583 | if isinstance(key, (int, slice)): |
||
584 | return key |
||
585 | elif isinstance(key, str): |
||
586 | return self.names.index(key) |
||
587 | elif isinstance(key, ModelBase): |
||
588 | return self._models.index(key) |
||
589 | else: |
||
590 | raise TypeError(f"Invalid type: {type(key)!r}") |
||
591 | |||
    def __len__(self):
        """Number of models in the container."""
        return len(self._models)

    def _ipython_key_completions_(self):
        # IPython hook: model names offered when tab-completing ``models[...]``
        return self.names

    def copy(self):
        """A deep copy."""
        return copy.deepcopy(self)
||
601 | |||
602 | def select( |
||
603 | self, |
||
604 | name_substring=None, |
||
605 | datasets_names=None, |
||
606 | tag=None, |
||
607 | model_type=None, |
||
608 | frozen=None, |
||
609 | ): |
||
610 | """Select models that meet all specified conditions |
||
611 | |||
612 | Parameters |
||
613 | ---------- |
||
614 | |||
615 | name_substring : str |
||
616 | Substring contained in the model name |
||
617 | datasets_names : str or list |
||
618 | Name of the dataset |
||
619 | tag : str or list |
||
620 | Model tag |
||
621 | model_type : {None, spatial, spectral} |
||
622 | Type of model, used together with "tag", if the tag is not unique. |
||
623 | frozen : bool |
||
624 | Select models with all parameters frozen if True, exclude them if False. |
||
625 | |||
626 | Returns |
||
627 | ------- |
||
628 | models : `DatasetModels` |
||
629 | Selected models |
||
630 | """ |
||
631 | mask = self.selection_mask( |
||
632 | name_substring, datasets_names, tag, model_type, frozen |
||
633 | ) |
||
634 | return self[mask] |
||
635 | |||
636 | def selection_mask( |
||
637 | self, |
||
638 | name_substring=None, |
||
639 | datasets_names=None, |
||
640 | tag=None, |
||
641 | model_type=None, |
||
642 | frozen=None, |
||
643 | ): |
||
644 | """Create a mask of models, that meet all specified conditions |
||
645 | |||
646 | Parameters |
||
647 | ---------- |
||
648 | name_substring : str |
||
649 | Substring contained in the model name |
||
650 | datasets_names : str or list of str |
||
651 | Name of the dataset |
||
652 | tag : str or list of str |
||
653 | Model tag |
||
654 | model_type : {None, spatial, spectral} |
||
655 | Type of model, used together with "tag", if the tag is not unique. |
||
656 | frozen : bool |
||
657 | Select models with all parameters frozen if True, exclude them if False. |
||
658 | |||
659 | Returns |
||
660 | ------- |
||
661 | mask : `numpy.array` |
||
662 | Boolean mask, True for selected models |
||
663 | """ |
||
664 | selection = np.ones(len(self), dtype=bool) |
||
665 | |||
666 | if tag and not isinstance(tag, list): |
||
667 | tag = [tag] |
||
668 | |||
669 | if datasets_names and not isinstance(datasets_names, list): |
||
670 | datasets_names = [datasets_names] |
||
671 | |||
672 | for idx, model in enumerate(self): |
||
673 | if name_substring: |
||
674 | selection[idx] &= name_substring in model.name |
||
675 | |||
676 | if datasets_names: |
||
677 | selection[idx] &= model.datasets_names is None or np.any( |
||
678 | [name in model.datasets_names for name in datasets_names] |
||
679 | ) |
||
680 | |||
681 | if tag: |
||
682 | if model_type is None: |
||
683 | sub_model = model |
||
684 | else: |
||
685 | sub_model = getattr(model, f"{model_type}_model", None) |
||
686 | |||
687 | if sub_model: |
||
688 | selection[idx] &= np.any([t in sub_model.tag for t in tag]) |
||
689 | else: |
||
690 | selection[idx] &= False |
||
691 | |||
692 | if frozen is not None: |
||
693 | if frozen: |
||
694 | selection[idx] &= model.frozen |
||
695 | else: |
||
696 | selection[idx] &= ~model.frozen |
||
697 | |||
698 | return np.array(selection, dtype=bool) |
||
699 | |||
700 | def select_mask(self, mask, margin="0 deg", use_evaluation_region=True): |
||
701 | """Check if sky models contribute within a mask map. |
||
702 | |||
703 | Parameters |
||
704 | ---------- |
||
705 | mask : `~gammapy.maps.WcsNDMap` of boolean type |
||
706 | Map containing a boolean mask |
||
707 | margin : `~astropy.unit.Quantity` |
||
708 | Add a margin in degree to the source evaluation radius. |
||
709 | Used to take into account PSF width. |
||
710 | use_evaluation_region : bool |
||
711 | Account for the extension of the model or not. The default is True. |
||
712 | |||
713 | Returns |
||
714 | ------- |
||
715 | models : `DatasetModels` |
||
716 | Selected models contributing inside the region where mask==True |
||
717 | """ |
||
718 | models = [] |
||
719 | |||
720 | if not mask.geom.is_image: |
||
721 | mask = mask.reduce_over_axes(func=np.logical_or) |
||
722 | |||
723 | for model in self.select(tag="sky-model"): |
||
724 | if use_evaluation_region: |
||
725 | contributes = model.contributes(mask=mask, margin=margin) |
||
726 | else: |
||
727 | contributes = mask.get_by_coord(model.position, fill_value=0) |
||
728 | |||
729 | if np.any(contributes): |
||
730 | models.append(model) |
||
731 | |||
732 | return self.__class__(models=models) |
||
733 | |||
734 | def select_region(self, regions, wcs=None): |
||
735 | """Select sky models with center position contained within a given region |
||
736 | |||
737 | Parameters |
||
738 | ---------- |
||
739 | regions : str, `~regions.Region` or list of `~regions.Region` |
||
740 | Region or list of regions (pixel or sky regions accepted). |
||
741 | A region can be defined as a string ind DS9 format as well. |
||
742 | See http://ds9.si.edu/doc/ref/region.html for details. |
||
743 | wcs : `~astropy.wcs.WCS` |
||
744 | World coordinate system transformation |
||
745 | |||
746 | Returns |
||
747 | ------- |
||
748 | models : `DatasetModels` |
||
749 | Selected models |
||
750 | """ |
||
751 | geom = RegionGeom.from_regions(regions, wcs=wcs) |
||
752 | |||
753 | models = [] |
||
754 | |||
755 | for model in self.select(tag="sky-model"): |
||
756 | if geom.contains(model.position): |
||
757 | models.append(model) |
||
758 | |||
759 | return self.__class__(models=models) |
||
760 | |||
    def restore_status(self, restore_values=True):
        """Context manager to restore status.

        A copy of the values is made on enter,
        and those values are restored on exit.

        Parameters
        ----------
        restore_values : bool
            Restore values if True,
            otherwise restore only frozen status and covariance matrix.

        """
        # Delegates to the module helper ``restore_models_status`` (not shown
        # in this part of the file).
        return restore_models_status(self, restore_values)
||
775 | |||
776 | def set_parameters_bounds( |
||
777 | self, tag, model_type, parameters_names, min=None, max=None, value=None |
||
778 | ): |
||
779 | """Set bounds for the selected models types and parameters names |
||
780 | |||
781 | Parameters |
||
782 | ---------- |
||
783 | tag : str or list |
||
784 | tag of the models |
||
785 | model_type : {"spatial", "spectral"} |
||
786 | type of models |
||
787 | parameters_names : str or list |
||
788 | parameters names |
||
789 | min : float |
||
790 | min value |
||
791 | max : float |
||
792 | max value |
||
793 | value : float |
||
794 | init value |
||
795 | """ |
||
796 | |||
797 | models = self.select(tag=tag, model_type=model_type) |
||
798 | parameters = models.parameters.select(name=parameters_names, type=model_type) |
||
799 | n = len(parameters) |
||
800 | |||
801 | if min is not None: |
||
802 | parameters.min = np.ones(n) * min |
||
803 | if max is not None: |
||
804 | parameters.max = np.ones(n) * max |
||
805 | if value is not None: |
||
806 | parameters.value = np.ones(n) * value |
||
807 | |||
808 | def freeze(self, model_type=None): |
||
809 | """Freeze parameters depending on model type |
||
810 | |||
811 | Parameters |
||
812 | ---------- |
||
813 | model_type : {None, "spatial", "spectral"} |
||
814 | freeze all parameters or only spatial or only spectral |
||
815 | """ |
||
816 | |||
817 | for m in self: |
||
818 | m.freeze(model_type) |
||
819 | |||
820 | def unfreeze(self, model_type=None): |
||
821 | """Restore parameters frozen status to default depending on model type |
||
822 | |||
823 | Parameters |
||
824 | ---------- |
||
825 | model_type : {None, "spatial", "spectral"} |
||
826 | restore frozen status to default for all parameters or only spatial or only spectral |
||
827 | """ |
||
828 | |||
829 | for m in self: |
||
830 | m.unfreeze(model_type) |
||
831 | |||
832 | @property |
||
833 | def frozen(self): |
||
834 | """Boolean mask, True if all parameters of a given model are frozen""" |
||
835 | return np.all([m.frozen for m in self]) |
||
836 | |||
837 | def reassign(self, dataset_name, new_dataset_name): |
||
838 | """Reassign a model from one dataset to another |
||
839 | |||
840 | Parameters |
||
841 | ---------- |
||
842 | dataset_name : str or list |
||
843 | Name of the datasets where the model is currently defined |
||
844 | new_dataset_name : str or list |
||
845 | Name of the datasets where the model should be defined instead. |
||
846 | If multiple names are given the two list must have the save length, |
||
847 | as the reassignment is element-wise. |
||
848 | """ |
||
849 | models = [m.reassign(dataset_name, new_dataset_name) for m in self] |
||
850 | return self.__class__(models) |
||
851 | |||
852 | def to_template_sky_model(self, geom, spectral_model=None, name=None): |
||
853 | """Merge a list of models into a single `~gammapy.modeling.models.SkyModel` |
||
854 | |||
855 | Parameters |
||
856 | ---------- |
||
857 | spectral_model : `~gammapy.modeling.models.SpectralModel` |
||
858 | One of the NormSpectralMdel |
||
859 | name : str |
||
860 | Name of the new model |
||
861 | |||
862 | """ |
||
863 | from . import SkyModel, TemplateSpatialModel, PowerLawNormSpectralModel |
||
864 | |||
865 | unit = u.Unit("1 / (cm2 s sr TeV)") |
||
866 | map_ = Map.from_geom(geom, unit=unit) |
||
867 | for m in self: |
||
868 | map_ += m.evaluate_geom(geom).to(unit) |
||
869 | spatial_model = TemplateSpatialModel(map_, normalize=False) |
||
870 | if spectral_model is None: |
||
871 | spectral_model = PowerLawNormSpectralModel() |
||
872 | return SkyModel( |
||
873 | spectral_model=spectral_model, spatial_model=spatial_model, name=name |
||
874 | ) |
||
875 | |||
876 | @property |
||
877 | def positions(self): |
||
878 | """Positions of the models (`SkyCoord`)""" |
||
879 | positions = [] |
||
880 | |||
881 | for model in self.select(tag="sky-model"): |
||
882 | if model.position: |
||
883 | positions.append(model.position) |
||
884 | else: |
||
885 | log.warning( |
||
886 | f"Skipping model {model.name} - no spatial component present" |
||
887 | ) |
||
888 | |||
889 | return SkyCoord(positions) |
||
890 | |||
891 | def to_regions(self): |
||
892 | """Returns a list of the regions for the spatial models |
||
893 | |||
894 | Returns |
||
895 | ------- |
||
896 | regions: list of `~regions.SkyRegion` |
||
897 | Regions |
||
898 | """ |
||
899 | regions = [] |
||
900 | |||
901 | for model in self.select(tag="sky-model"): |
||
902 | try: |
||
903 | region = model.spatial_model.to_region() |
||
904 | regions.append(region) |
||
905 | except AttributeError: |
||
906 | log.warning( |
||
907 | f"Skipping model {model.name} - no spatial component present" |
||
908 | ) |
||
909 | return regions |
||
910 | |||
911 | @property |
||
912 | def wcs_geom(self): |
||
913 | """Minimum WCS geom in which all the models are contained """ |
||
914 | regions = self.to_regions() |
||
915 | try: |
||
916 | return RegionGeom.from_regions(regions).to_wcs_geom() |
||
917 | except IndexError: |
||
918 | log.error("No spatial component in any model. Geom not defined") |
||
919 | |||
def plot_regions(self, ax=None, kwargs_point=None, path_effect=None, **kwargs):
    """Plot extent of the spatial models on a given WCS axis.

    Point regions (`~regions.PointSkyRegion`) are drawn as markers using
    ``kwargs_point``. All other regions are drawn using ``kwargs``.

    Parameters
    ----------
    ax : `~astropy.visualization.wcsaxes.WCSAxes`, optional
        Axes to plot on. If ``ax`` is None or not a `WCSAxes`, an empty map
        built from `wcs_geom` is plotted and its axes are used instead.
    kwargs_point : dict, optional
        Keyword arguments passed to `~matplotlib.lines.Line2D` for plotting
        of point sources. The dict passed in is not modified.
    path_effect : `~matplotlib.patheffects.AbstractPathEffect`, optional
        Path effect applied to artists and lines.
    **kwargs : dict
        Keyword arguments passed to `~matplotlib.artist.Artist`.

    Returns
    -------
    ax : `~astropy.visualization.wcsaxes.WCSAxes`
        WCS axes.
    """
    from astropy.visualization.wcsaxes import WCSAxes

    # Copy so the defaults filled in below never leak into the caller's dict.
    kwargs_point = dict(kwargs_point or {})

    if not isinstance(ax, WCSAxes):
        ax = Map.from_geom(self.wcs_geom).plot()

    kwargs.setdefault("color", "tab:blue")
    # The string "None" means "no face colour" in matplotlib, so extended
    # regions are drawn as outlines only.
    kwargs.setdefault("fc", "None")
    kwargs_point.setdefault("marker", "*")
    kwargs_point.setdefault("markersize", 10)
    kwargs_point.setdefault("markeredgecolor", "None")
    kwargs_point.setdefault("color", kwargs["color"])

    for region in self.to_regions():
        pixel_region = region.to_pixel(ax.wcs)
        if isinstance(region, PointSkyRegion):
            artist = pixel_region.as_artist(**kwargs_point)
        else:
            artist = pixel_region.as_artist(**kwargs)

        if path_effect:
            artist.set_path_effects([path_effect])

        ax.add_artist(artist)

    return ax
||
968 | |||
def plot_positions(self, ax=None, **kwargs):
    """Plot the centers of the spatial models on a given WCS axis.

    Parameters
    ----------
    ax : `~astropy.visualization.wcsaxes.WCSAxes`, optional
        Axes to plot on. If ``ax`` is None or not a `WCSAxes`, an empty map
        built from `wcs_geom` is plotted and its axes are used instead.
    **kwargs : dict
        Keyword arguments passed to `~matplotlib.pyplot.scatter`. A
        ``path_effects`` entry is applied to the resulting collection.

    Returns
    -------
    ax : `~astropy.visualization.wcsaxes.WCSAxes`
        WCS axes.
    """
    from astropy.visualization.wcsaxes import WCSAxes
    import matplotlib.pyplot as plt

    if not isinstance(ax, WCSAxes):
        ax = Map.from_geom(self.wcs_geom).plot()

    kwargs.setdefault("marker", "*")
    kwargs.setdefault("color", "tab:blue")
    # Take the path effects out of the scatter kwargs and apply them once,
    # explicitly, to the returned collection.
    path_effects = kwargs.pop("path_effects", None)

    xp, yp = self.positions.to_pixel(ax.wcs)
    collection = ax.scatter(xp, yp, **kwargs)

    if path_effects:
        plt.setp(collection, path_effects=path_effects)

    return ax
||
1003 | |||
1004 | |||
class Models(DatasetModels, collections.abc.MutableSequence):
    """Sky model collection.

    Model names within the collection are kept unique.

    Parameters
    ----------
    models : `SkyModel`, list of `SkyModel` or `Models`
        Sky models
    """

    def __delitem__(self, key):
        """Remove the model(s) identified by ``key`` (resolved via ``self.index``)."""
        del self._models[self.index(key)]

    def __setitem__(self, key, model):
        """Replace the model identified by ``key`` with ``model``.

        Raises
        ------
        TypeError
            If ``model`` is not a `SkyModel` or `FoVBackgroundModel`.
        ValueError
            If a model other than the one being replaced already uses
            ``model.name``.
        """
        from gammapy.modeling.models import SkyModel, FoVBackgroundModel

        if not isinstance(model, (SkyModel, FoVBackgroundModel)):
            raise TypeError(f"Invalid type: {model!r}")

        idx = self.index(key)

        # Enforce the same name-uniqueness rule as ``insert``. The entry being
        # replaced is removed from the comparison (``del`` works for negative
        # ints and slices alike), so re-using its own name is allowed.
        other_names = [m.name for m in self._models]
        del other_names[idx]
        if model.name in other_names:
            raise ValueError("Model names must be unique")

        self._models[idx] = model

    def insert(self, idx, model):
        """Insert ``model`` before position ``idx``.

        Raises
        ------
        ValueError
            If a model with the same name is already in the collection.
        """
        if model.name in self.names:
            raise ValueError("Model names must be unique")

        self._models.insert(idx, model)
||
1030 | |||
1031 | |||
class restore_models_status:
    """Context manager that restores the status of model parameters on exit.

    On construction it records, for each parameter in ``models.parameters``,
    the ``value`` and ``frozen`` attributes, plus a copy of
    ``models.covariance.data``. On exit (normal or via an exception):

    - the ``frozen`` flags are always written back;
    - the covariance is always written back;
    - the values are written back only if ``restore_values`` is True.

    Exceptions raised inside the ``with`` block are not suppressed.

    Parameters
    ----------
    models : `DatasetModels`
        Models exposing ``parameters`` and a ``covariance`` with ``data``.
    restore_values : bool
        Whether to restore the parameter values as well. Default is True.
    """

    def __init__(self, models, restore_values=True):
        self.restore_values = restore_values
        self.models = models
        self.values = [par.value for par in models.parameters]
        self.frozen = [par.frozen for par in models.parameters]
        # Copy the snapshot so in-place edits of the live covariance array
        # during the ``with`` block cannot alter the state restored on exit.
        self.covariance_data = models.covariance.data.copy()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Parameters are matched to the snapshot by position, so this assumes
        # the parameter list is not reordered inside the ``with`` block.
        for par, value, frozen in zip(self.models.parameters, self.values, self.frozen):
            if self.restore_values:
                par.value = value
            par.frozen = frozen
        self.models.covariance = self.covariance_data
        # Implicitly returns None, so any exception propagates.
||
1049 |