| Total Complexity | 41 |
| Total Lines | 443 |
| Duplicated Lines | 5.42 % |
| Changes | 0 | ||
Duplicate code is one of the most pungent code smells. A commonly used rule is to restructure code once it is duplicated in three or more places.
Common duplication problems and their corresponding solutions are:
Complex classes like sciapy.regress.models_theano often do a lot of different things. To break such a class down, we need to identify a cohesive component within that class. A common approach to find such a component is to look for fields/methods that share the same prefixes, or suffixes.
Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a sub-class, Extract Subclass is also a candidate, and is often faster.
| 1 | # -*- coding: utf-8 -*- |
||
| 2 | # vim:fileencoding=utf-8 |
||
| 3 | # |
||
| 4 | # Copyright (c) 2022 Stefan Bender |
||
| 5 | # |
||
| 6 | # This module is part of sciapy. |
||
| 7 | # sciapy is free software: you can redistribute it or modify |
||
| 8 | # it under the terms of the GNU General Public License as published |
||
| 9 | # by the Free Software Foundation, version 2. |
||
| 10 | # See accompanying LICENSE file or http://www.gnu.org/licenses/gpl-2.0.html. |
||
| 11 | """SCIAMACHY regression models (theano/pymc3 version) |
||
| 12 | |||
| 13 | Model classes for SCIAMACHY data regression fits using |
||
| 14 | :mod:`theano` for :mod:`pymc3`. |
||
| 15 | |||
| 16 | This interface is still experimental. |
||
| 17 | """ |
||
| 18 | from __future__ import absolute_import, division, print_function |
||
| 19 | from warnings import warn |
||
| 20 | |||
| 21 | import numpy as np |
||
| 22 | |||
| 23 | try: |
||
| 24 | import aesara_theano_fallback.tensor as tt |
||
| 25 | except ImportError as err: |
||
| 26 | raise ImportError( |
||
| 27 | "The `aesara_theano_fallback` package is required for the `theano` model interface." |
||
| 28 | ).with_traceback(err.__traceback__) |
||
| 29 | try: |
||
| 30 | import pymc3 as pm |
||
| 31 | except ImportError as err: |
||
| 32 | raise ImportError( |
||
| 33 | "The `pymc3` package is required for the `theano` model interface." |
||
| 34 | ).with_traceback(err.__traceback__) |
||
| 35 | |||
# Public API of this module (used by ``from ... import *``).
__all__ = [
    "HarmonicModelCosineSine",
    "HarmonicModelAmpPhase",
    "LifetimeModel",
    "ProxyModel",
    "ModelSet",
    "setup_proxy_model_theano",
    "trace_gas_modelset",
]
| 43 | ] |
||
| 44 | |||
| 45 | |||
class HarmonicModelCosineSine:
    """Model for harmonic terms

    Harmonic term written as the sum of a cosine and a sine part,
    ``cos * cos(omega * t) + sin * sin(omega * t)`` with
    ``omega = 2 * pi * freq``. The total amplitude and phase can be
    derived from the two coefficients.

    Parameters
    ----------
    freq : float
        The frequency in years^-1
    cos : float
        The amplitude of the cosine part
    sin : float
        The amplitude of the sine part
    """
    def __init__(self, freq, cos, sin):
        # all three are stored as float64 tensors, in this order
        self.omega, self.cos, self.sin = (
            tt.as_tensor_variable(v).astype("float64")
            for v in (2 * np.pi * freq, cos, sin)
        )

    def get_value(self, t):
        """Evaluate the harmonic term at the times `t`."""
        wt = self.omega * tt.as_tensor_variable(t).astype("float64")
        return self.cos * tt.cos(wt) + self.sin * tt.sin(wt)

    def get_amplitude(self):
        """Total amplitude ``sqrt(cos**2 + sin**2)``."""
        return tt.sqrt(self.cos**2 + self.sin**2)

    def get_phase(self):
        """Phase ``phi`` such that the term equals ``amp * sin(omega * t + phi)``."""
        return tt.arctan2(self.cos, self.sin)

    def compute_gradient(self, t):
        """Partial derivatives of the term w.r.t. (freq, cos, sin) at `t`."""
        t = tt.as_tensor_variable(t).astype("float64")
        wt = self.omega * t
        d_cos = tt.cos(wt)
        d_sin = tt.sin(wt)
        d_freq = 2 * np.pi * t * (self.sin * d_cos - self.cos * d_sin)
        return d_freq, d_cos, d_sin
||
| 85 | |||
| 86 | |||
class HarmonicModelAmpPhase:
    """Model for harmonic terms

    Harmonic term written as ``amp * sin(omega * t + phase)`` with
    ``omega = 2 * pi * freq``.

    Parameters
    ----------
    freq : float
        The frequency in years^-1
    amp : float
        The amplitude of the harmonic term
    phase : float
        The phase of the harmonic part
    """
    def __init__(self, freq, amp, phase):
        # all three are stored as float64 tensors, in this order
        self.omega, self.amp, self.phase = (
            tt.as_tensor_variable(v).astype("float64")
            for v in (2 * np.pi * freq, amp, phase)
        )

    def get_value(self, t):
        """Evaluate the harmonic term at the times `t`."""
        arg = self.omega * tt.as_tensor_variable(t).astype("float64") + self.phase
        return self.amp * tt.sin(arg)

    def get_amplitude(self):
        """The amplitude parameter."""
        return self.amp

    def get_phase(self):
        """The phase parameter."""
        return self.phase

    def compute_gradient(self, t):
        """Partial derivatives of the term w.r.t. (freq, amp, phase) at `t`."""
        t = tt.as_tensor_variable(t).astype("float64")
        arg = self.omega * t + self.phase
        d_amp = tt.sin(arg)
        d_phase = self.amp * tt.cos(arg)
        return 2 * np.pi * t * d_phase, d_amp, d_phase
||
| 122 | |||
| 123 | |||
class LifetimeModel:
    """Model for variable lifetime

    Sums one or more harmonic terms and clips the result from below,
    i.e. returns ``max(lower, sum_i h_i.get_value(t))``.

    Parameters
    ----------
    harmonics : HarmonicModelCosineSine or HarmonicModelAmpPhase or list
        A single harmonic model or a sequence (list, tuple) of them.
    lower : float, optional
        Lower limit of the returned values, default 0.
    """
    def __init__(self, harmonics, lower=0.):
        # A sequence of harmonics is detected via the `__getitem__`
        # protocol method; the previous check for "getitem" never matched,
        # so a list got wrapped again (`[[h1, h2]]`) and `get_value` failed.
        if hasattr(harmonics, "__getitem__"):
            harmonics = list(harmonics)
        else:
            harmonics = [harmonics]
        self.harmonics = harmonics
        self.lower = lower

    def get_value(self, t):
        """Sum of all harmonic terms at `t`, clipped below at `lower`."""
        t = tt.as_tensor_variable(t)
        tau_cs = tt.zeros(t.shape[:-1], dtype="float64")
        for h in self.harmonics:
            tau_cs += h.get_value(t)
        return tt.maximum(self.lower, tau_cs)
||
| 144 | |||
| 145 | |||
def _interp(x, xs, ys, fill_value=0.):
    """Piecewise-linear interpolation of `ys` sampled at `xs` to `x`

    Symbolic analogue of ``numpy.interp(x, xs, ys, left=fill_value,
    right=fill_value)``.

    Parameters
    ----------
    x : tensor
        The points at which to interpolate.
    xs : tensor
        The sample points; presumably sorted in ascending order, which
        `searchsorted` requires (TODO confirm at call sites).
    ys : tensor
        The values at `xs`.
    fill_value : float, optional
        Value returned for `x` outside ``[xs[0], xs[-1]]``, default 0.

    Returns
    -------
    tensor (float64)
        The interpolated values.
    """
    n = xs.shape[0]
    # index of the right interval end, kept within the valid range [1, n - 1]
    idx = tt.clip(xs.searchsorted(x), 1, n - 1)
    # Both end points are valid nodes (as in numpy.interp); checking
    # `idx < 1` instead made x == xs[0] fall back to `fill_value`.
    out_of_bounds = (x < xs[0]) | (x > xs[-1])
    x_lo = xs[idx - 1]
    x_hi = xs[idx]
    w_lo = (x_hi - x) / (x_hi - x_lo)
    ret = (w_lo * ys[idx - 1] + (1 - w_lo) * ys[idx]).astype("float64")
    return tt.switch(out_of_bounds, fill_value, ret)
||
| 159 | |||
| 160 | |||
class ProxyModel:
    """Model for proxy terms

    Models proxy terms with a finite and (semi-)annually varying life time.

    Parameters
    ----------
    ptimes : (N,) array_like
        The data times of the proxy values, in the model time units.
    pvalues : (N,) array_like
        The proxy values at `ptimes`
    amp : float
        The amplitude of the proxy term
    lag : float, optional
        The lag of the proxy value in days, it is divided by
        `days_per_time_unit` to convert it to model time units.
    tau0 : float, optional
        The base life time of the proxy in days.
    tau_harm : LifetimeModel, optional
        The lifetime harmonic model with a lower limit, added to `tau0`.
    tau_scan : int, optional
        The number of days to sum the previous proxy values.
        No lifetime adjustments are calculated when set to zero (the
        default). Negative values are not supported by this interface
        and currently also result in no lifetime correction.
    days_per_time_unit : float, optional
        The number of days per time unit, used to normalize the lifetime
        units. Use 365.25 if the times are in fractional years, or 1 if
        they are in days.
        Default: 365.25
    """
    def __init__(
        self, ptimes, pvalues, amp,
        lag=0.,
        tau0=0.,
        tau_harm=None,
        tau_scan=0,
        days_per_time_unit=365.25,
    ):
        # data
        self.times = tt.as_tensor_variable(ptimes).astype("float64")
        self.values = tt.as_tensor_variable(pvalues).astype("float64")
        # parameters
        self.amp = tt.as_tensor_variable(amp).astype("float64")
        self.days_per_time_unit = tt.as_tensor_variable(days_per_time_unit).astype("float64")
        # `lag` is given in days, converted to model time units
        self.lag = tt.as_tensor_variable(lag / days_per_time_unit).astype("float64")
        self.tau0 = tt.as_tensor_variable(tau0).astype("float64")
        self.tau_harm = tau_harm
        self.tau_scan = tau_scan
        # one-day steps back in time (expressed in model time units)
        dt = 1.0
        bs = np.arange(dt, tau_scan + dt, dt) / days_per_time_unit
        self.bs = tt.as_tensor_variable(bs).astype("float64")
        self.dt = tt.as_tensor_variable(dt).astype("float64")
        # Decide on the plain inputs: on symbolic tensors `==` is an
        # identity check (always False here) and a symbolic `>` cannot be
        # used in an `if`, so the adjustment was never applied before.
        self.t_adj = tt.as_tensor_variable(
            self._time_adjustment(ptimes, days_per_time_unit)
        ).astype("float64")

    @staticmethod
    def _time_adjustment(ptimes, days_per_time_unit):
        """Time offset for the lifetime seasonal variation

        Makes "(m)jd" and "jyear" compatible for the lifetime seasonal
        variation. The julian epoch (the default) is slightly offset with
        respect to (modified) julian days.

        Returns 0. unless `days_per_time_unit` is 1. For day-based times,
        julian days and modified julian days are told apart by the first
        time value: 1.8e6 is year 216 in julian days and year 6787 in
        modified julian days, which should be pretty safe for most use
        cases. Julian days give 13., modified julian days -44.25.
        If the first time value is not available as a number (empty or
        symbolic times), no adjustment (0.) is applied.
        """
        if days_per_time_unit != 1:
            return 0.
        try:
            t_first = float(np.ravel(ptimes)[0])
        except (IndexError, TypeError, ValueError):
            return 0.
        if t_first > 1.8e6:
            # julian days
            return 13.
        # modified julian days
        return -44.25

    def _lt_corr(self, t, tau):
        """Lifetime corrected values

        Corrects for a finite lifetime by summing the (lagged) proxy values
        over the last `tau_scan` days, each weighted with an exponential
        decay ``exp(-n * dt / tau)`` for lifetime(s) `tau` (in days).
        """
        yp = tt.zeros(t.shape[:-1], dtype="float64")
        tauexp = tt.exp(-self.dt / tau)
        taufac = tt.ones(tau.shape[:-1], dtype="float64")
        for b in self.bs:
            taufac *= tauexp
            yp += taufac * _interp(
                t - self.lag - b,
                self.times, self.values,
            )
        return yp * self.dt

    def get_value(self, t):
        """Proxy term at the times `t`, including the lifetime correction."""
        t = tt.as_tensor_variable(t)
        proxy_val = _interp(
            t - self.lag,
            self.times, self.values,
        )
        if self.tau_scan == 0:
            # no lifetime, nothing else to do
            return self.amp * proxy_val
        tau = self.tau0
        if self.tau_harm is not None:
            tau_cs = self.tau_harm.get_value(t + self.t_adj)
            tau += tau_cs
        proxy_val += self._lt_corr(t, tau)
        return self.amp * proxy_val
||
| 261 | |||
| 262 | |||
class ModelSet:
    """Sum of several model components

    Parameters
    ----------
    models : list
        The component models, each providing a ``get_value(t)`` method.
    """
    def __init__(self, models):
        self.models = models

    def get_value(self, t):
        """Return the sum of all component values at the times `t`."""
        total = tt.zeros(t.shape[:-1], dtype="float64")
        for component in self.models:
            total = total + component.get_value(t)
        return total
||
| 272 | |||
| 273 | |||
def setup_proxy_model_theano(
    model, name,
    times, values,
    max_amp=1e10, max_days=100,
    **kwargs
):
    """Set up a :class:`ProxyModel` with :mod:`pymc3` priors (experimental)

    Creates the random variables for the proxy amplitude and, optionally,
    the lag and the seasonally varying lifetime within `model`, and
    returns a :class:`ProxyModel` built from them.

    Parameters
    ----------
    model : pymc3.Model
        The model context in which the random variables are created.
    name : str
        The proxy name, used to prefix the random variable names.
    times : (N,) array_like
        The data times of the proxy values.
    values : (N,) array_like
        The proxy values at `times`.
    max_amp : float, optional
        Standard deviation of the amplitude prior(s); for positive
        amplitudes its log is used as the width of the log-amplitude prior.
    max_days : float, optional
        Its log is used as the width of the log-lag and log-lifetime priors.
    **kwargs : optional
        Further setup options, unknown keys are ignored:

        * fit_lag : bool, default False
            Fit the lag instead of using the fixed `lag`.
        * lag : float, default 0.
            The fixed lag (in days) if `fit_lag` is False.
        * lifetime_scan : int, default 60
            Number of days to sum previous proxy values over; a
            positive value also adds the lifetime parameters.
        * positive : bool, default False
            Restrict the amplitude to positive values (log-normal prior).
        * time_format : str, default "jyear"
            Formats ending in "d" (e.g. "jd", "mjd") are treated as
            days, all others as fractional (julian) years.

    Returns
    -------
    proxy : ProxyModel
    """
    warn(
        "This method to set up the `theano`/`pymc3` interface is experimental, "
        "and the interface will most likely change in future versions. "
        "It is recommended to use the `ProxyModel` class instead."
    )
    # extract setup from `kwargs`
    fit_lag = kwargs.get("fit_lag", False)
    lag = kwargs.get("lag", 0.)
    lifetime_scan = kwargs.get("lifetime_scan", 60)
    positive = kwargs.get("positive", False)
    time_format = kwargs.get("time_format", "jyear")
    days_per_time_unit = 1 if time_format.endswith("d") else 365.25

    with model:
        if positive:
            log_amp = pm.Normal("log_{0}_amp".format(name), mu=0.0, sd=np.log(max_amp))
            amp = pm.Deterministic("{0}_amp".format(name), pm.math.exp(log_amp))
        else:
            amp = pm.Normal("{0}_amp".format(name), mu=0.0, sd=max_amp)
        if fit_lag:
            log_lag = pm.Normal("log_{0}_lag".format(name), mu=-5.0, sd=np.log(max_days))
            lag = pm.Deterministic("{0}_lag".format(name), pm.math.exp(log_lag))
        if lifetime_scan > 0:
            log_tau0 = pm.Normal("log_{0}_tau0".format(name), mu=-5.0, sd=np.log(max_days))
            tau0 = pm.Deterministic("{0}_tau0".format(name), pm.math.exp(log_tau0))
            cos1 = pm.Normal("{0}_tau_cos1".format(name), mu=0.0, sd=max_amp)
            sin1 = pm.Normal("{0}_tau_sin1".format(name), mu=0.0, sd=max_amp)
            # Annual lifetime variation: one cycle per year, expressed per
            # model time unit (1 for "jyear", 1 / 365.25 for day-based times).
            harm1 = HarmonicModelCosineSine(days_per_time_unit / 365.25, cos1, sin1)
            tau1 = LifetimeModel(harm1, lower=0)
        else:
            tau0 = 0.
            tau1 = None
        proxy = ProxyModel(
            times, values,
            amp,
            lag=lag,
            tau0=tau0,
            tau_harm=tau1,
            tau_scan=lifetime_scan,
            days_per_time_unit=days_per_time_unit,
        )
    return proxy
||
| 321 | |||
| 322 | |||
def _default_proxy_config(tfmt="jyear"):
    """Standard proxy configuration: Lyman-alpha and the AE ("GM") index

    Lyman-alpha is configured without lifetime (``lifetime_scan=0``) and
    with an unconstrained amplitude, AE with a 60-day lifetime scan and a
    positive amplitude.

    Parameters
    ----------
    tfmt : str, optional
        Time format passed on to the data loading functions.

    Returns
    -------
    dict
        Proxy name -> keyword arguments for the proxy setup.
    """
    from .load_data import load_dailymeanLya, load_dailymeanAE

    # Lyman-alpha
    lya_times, lya_data = load_dailymeanLya(tfmt=tfmt)
    lya_conf = dict(
        times=lya_times,
        values=lya_data["Lya"],
        lifetime_scan=0,
        positive=False,
    )
    # AE index
    ae_times, ae_data = load_dailymeanAE(name="GM", tfmt=tfmt)
    ae_conf = dict(
        times=ae_times,
        values=ae_data["GM"],
        lifetime_scan=60,
        positive=True,
    )
    return {"Lya": lya_conf, "GM": ae_conf}
||
| 347 | |||
| 348 | |||
def trace_gas_modelset(constant=True, freqs=None, proxy_config=None, **kwargs):
    """Trace gas model set

    Sets up the trace gas model for easy access. All parameters are optional,
    defaults to an offset, no harmonics, proxies are uncentered and unscaled
    Lyman-alpha and AE. AE with only positive amplitude and a seasonally
    varying lifetime.

    Parameters
    ----------
    constant : bool, optional
        Whether or not to include a constant (offset) term, default is True.
    freqs : list, optional
        Frequencies of the harmonic terms in a^-1 (inverse years).
        Default: None, no harmonic terms.
    proxy_config : dict, optional
        Proxy configuration if different from the standard setup, mapping
        proxy names to keyword arguments for `setup_proxy_model_theano()`.
        The passed dictionaries are not modified.
    **kwargs : optional
        Additional keyword arguments, all of them are also passed on to
        the proxy setup. For now, supported are the following which are
        also passed along to the proxy setup with
        `setup_proxy_model_theano()`:

        * fit_phase : bool
            fit amplitude and phase instead of sine and cosine
        * scale : float
            the factor by which the data is scaled, used to constrain
            the maximum and minimum amplitudes to be fitted.
        * time_format : string
            The `astropy.time.Time` format string to setup the time axis.
        * days_per_time_unit : float
            The number of days per time unit, used to normalize the frequencies
            for the harmonic terms. Use 365.25 if the times are in fractional years,
            1 if they are in days. Default: 365.25
        * max_amp : float
            Maximum magnitude of the coefficients, used to constrain the
            parameter search.
        * max_days : float
            Maximum magnitude of the lifetimes, used to constrain the
            parameter search.

    Returns
    -------
    model : pymc3.Model
        The model containing all random variables.
    modelset : ModelSet
        The harmonic terms followed by the proxy terms.
    offset : pymc3 random variable or float
        The constant offset term, 0. if `constant` is False.
    """
    warn(
        "This method to set up the `theano`/`pymc3` interface is experimental, "
        "and the interface will most likely change in future versions. "
        "It is recommended to use the `ProxyModel` class instead."
    )
    fit_phase = kwargs.get("fit_phase", False)
    scale = kwargs.get("scale", 1e-6)
    tfmt = kwargs.get("time_format", "jyear")
    delta_t = kwargs.get("days_per_time_unit", 365.25)

    max_amp = kwargs.pop("max_amp", 1e10 * scale)
    max_days = kwargs.pop("max_days", 100)

    if freqs is None:
        # documented default: no harmonic terms
        freqs = []

    proxy_config = proxy_config or _default_proxy_config(tfmt=tfmt)

    with pm.Model() as model:
        offset = 0.
        if constant:
            offset = pm.Normal("offset", mu=0.0, sd=max_amp)

        modelset = []
        for freq in freqs:
            # convert from a^-1 to the inverse model time unit
            freq_unit = freq * delta_t / 365.25
            if not fit_phase:
                cos = pm.Normal("cos{0}".format(freq), mu=0., sd=max_amp)
                sin = pm.Normal("sin{0}".format(freq), mu=0., sd=max_amp)
                harm = HarmonicModelCosineSine(freq_unit, cos, sin)
            else:
                amp = pm.Normal("amp{0}".format(freq), mu=0., sd=max_amp)
                phase = pm.Normal("phase{0}".format(freq), mu=0., sd=max_amp)
                harm = HarmonicModelAmpPhase(freq_unit, amp, phase)
            modelset.append(harm)

        for pn, conf in proxy_config.items():
            # Merge into a fresh dict so that neither `kwargs` nor the
            # caller's `proxy_config` entries are modified; per-proxy
            # settings take precedence over the global ones.
            kw = dict(kwargs, max_amp=max_amp, max_days=max_days)
            kw.update(conf)
            modelset.append(
                setup_proxy_model_theano(model, pn, **kw)
            )

    return model, ModelSet(modelset), offset
||
| 443 |