| Conditions | 2 |
| Total Lines | 61 |
| Code Lines | 46 |
| Lines | 0 |
| Ratio | 0 % |
| Changes | 0 | ||
Small methods make your code easier to understand, especially when combined with a good name. Besides, the smaller a method is, the easier it usually is to find a good name for it.
For example, if you find yourself adding comments to a method's body, this is usually a good sign that you should extract the commented part into a new method, using the comment as a starting point when choosing a good name for that new method.
Commonly applied refactorings include:
If many parameters/temporary variables are present:
| 1 | # -*- coding: utf-8 -*- |
||
def _proxy_log_normal_param(name):
    """Add a positive parameter ``name`` with a log-normal prior to the current model.

    Creates ``log_<name>`` ~ Normal(mu=0, sd=log(10)) and the deterministic
    ``<name> = exp(log_<name>)`` inside the active ``pm.Model`` context.

    Returns:
        The deterministic (exponentiated) variable.
    """
    log_param = pm.Normal("log_" + name, mu=0.0, sd=np.log(10.0))
    return pm.Deterministic(name, pm.math.exp(log_param))


def _proxy_harmonic_lifetime():
    """Build the first-order cosine/sine lifetime model in the current model.

    Creates ``tau_cos1`` and ``tau_sin1`` ~ Normal(mu=0, sd=10) inside the
    active ``pm.Model`` context and wraps them in ``LifetimeModel(..., lower=0)``.

    Returns:
        The ``LifetimeModel`` instance.
    """
    cos1 = pm.Normal("tau_cos1", mu=0.0, sd=10.0)
    sin1 = pm.Normal("tau_sin1", mu=0.0, sd=10.0)
    # NOTE(review): the leading 1. is presumably the harmonic frequency -- confirm
    harm1 = HarmonicModelCosineSine(1., cos1, sin1)
    return LifetimeModel(harm1, lower=0)


def _sample_proxy_posterior():
    """Sample the active model's posterior, starting from its MAP estimate.

    Returns:
        The ``InferenceData`` object from ``pm.sample`` (two chains of 1000
        draws after 1000 tuning steps).
    """
    maxlp0 = pm.find_MAP()
    return pm.sample(
        chains=2,
        draws=1000,
        tune=1000,
        init="jitter+adapt_full",
        # one fixed seed per chain keeps this long test reproducible
        random_seed=[286923464, 464329682],
        return_inferencedata=True,
        start=maxlp0,
        target_accept=0.9,
    )


@pytest.mark.long
def test_proxy_theano(xs, c=3.0, s=1.0):
    """Fit a ``ProxyModel`` with a harmonic lifetime to noisy synthetic data.

    The steps are:

    1. Seed numpy's global RNG.
    2. Build observations from ``_test_data(xs, c, s)`` plus Gaussian noise.
    3. Fit a ``ProxyModel`` whose amplitude, lag and base lifetime have
       log-normal priors and whose lifetime varies with one cosine/sine
       harmonic.
    4. Check the posterior medians against the expected values.

    Args:
        xs: sample points; must support ``xs.shape[0]``.
            NOTE(review): presumably supplied by a pytest fixture -- confirm.
        c: cosine coefficient passed to the data generators; also the
            expected ``proxy_tau_cos1`` median.
        s: sine coefficient passed to the data generators; also the
            expected ``proxy_tau_sin1`` median.
    """
    # Initialize random number generator
    np.random.seed(93457)
    # std. dev. of the Gaussian noise added to the data; the fitted
    # ``proxy_log_jitter`` is expected to recover log(noise_sd)
    noise_sd = 0.5

    # proxy "values"
    values = ys(xs, c, s)

    yp = _test_data(xs, c, s)
    yp += noise_sd * np.random.randn(xs.shape[0])

    # using "name" prefixes all variables with <name>_
    with pm.Model(name="proxy"):
        # amplitude, lag and lifetime priors, created in this order
        pamp = _proxy_log_normal_param("amp")
        plag = _proxy_log_normal_param("lag")
        ptau0 = _proxy_log_normal_param("tau0")
        tau1 = _proxy_harmonic_lifetime()

        proxy = ProxyModel(
            xs, values,
            amp=pamp,
            lag=plag,
            tau0=ptau0,
            tau_harm=tau1,
            tau_scan=10,
        )
        prox1 = proxy.get_value(xs)
        # Include "jitter"
        log_jitter = pm.Normal("log_jitter", mu=0.0, sd=4.0)
        pm.Normal("obs", mu=prox1, sd=pm.math.exp(log_jitter), observed=yp)

        trace = _sample_proxy_posterior()

    medians = trace.posterior.median(dim=("chain", "draw"))
    np.testing.assert_allclose(
        medians[[
            "proxy_amp", "proxy_lag", "proxy_tau0",
            "proxy_tau_cos1", "proxy_tau_sin1",
            "proxy_log_jitter",
        ]].to_array(),
        # NOTE(review): 3., 2., 1. are presumably the true amp, lag and tau0
        # used by ``_test_data`` -- confirm
        (3., 2., 1., c, s, np.log(noise_sd)),
        atol=3e-2, rtol=1e-2,
    )