Conditions | 2 |
Total Lines | 61 |
Code Lines | 46 |
Lines | 0 |
Ratio | 0 % |
Changes | 0 |
Small methods make your code easier to understand, in particular if combined with a good name. Besides, if your method is small, finding a good name is usually much easier.
For example, if you find yourself adding comments to a method's body, this is usually a good sign that you should extract the commented part into a new method, and use the comment as a starting point when coming up with a good name for this new method.
Commonly applied refactorings include: Extract Method.
If many parameters/temporary variables are present: Replace Temp with Query, Introduce Parameter Object, Preserve Whole Object.
1 | # -*- coding: utf-8 -*- |
||
@pytest.mark.long
def test_proxy_theano(xs, c=3.0, s=1.0):
    """Sample a ``ProxyModel`` fit to noisy data and check the posterior medians.

    Noisy observations are built from ``_test_data(xs, c, s)`` plus Gaussian
    noise with standard deviation 0.5. A PyMC model named ``"proxy"`` is then
    sampled, and the posterior medians of amplitude, lag, base lifetime, the
    two harmonic lifetime coefficients and the log-jitter are compared with
    ``(3, 2, 1, c, s, log(0.5))``.

    Parameters
    ----------
    xs : array-like
        Sample positions; only ``xs.shape[0]`` is used directly here.
        Presumably supplied by a pytest fixture -- TODO confirm.
    c : float, optional
        Passed to ``ys`` and ``_test_data``; also the expected median of the
        ``tau_cos1`` coefficient.
    s : float, optional
        Passed to ``ys`` and ``_test_data``; also the expected median of the
        ``tau_sin1`` coefficient.
    """
    # Fix the global NumPy seed so the noise added below is reproducible.
    np.random.seed(93457)

    # proxy "values"
    proxy_values = ys(xs, c, s)

    # Synthetic observations: test data plus Gaussian noise (sigma = 0.5),
    # added in place.
    noisy_obs = _test_data(xs, c, s)
    noisy_obs += 0.5 * np.random.randn(xs.shape[0])

    # Width shared by the three log-scale priors (amplitude, lag, lifetime).
    log_prior_sd = np.log(10.0)

    # using "name" prefixes all variables with <name>_
    with pm.Model(name="proxy") as model:
        # amplitude: normal prior on the log, exponentiated
        log_amp = pm.Normal("log_amp", mu=0.0, sd=log_prior_sd)
        amp = pm.Deterministic("amp", pm.math.exp(log_amp))

        # lag: normal prior on the log, exponentiated
        log_lag = pm.Normal("log_lag", mu=0.0, sd=log_prior_sd)
        lag = pm.Deterministic("lag", pm.math.exp(log_lag))

        # lifetime: constant part plus one cosine/sine harmonic
        log_tau0 = pm.Normal("log_tau0", mu=0.0, sd=log_prior_sd)
        tau0 = pm.Deterministic("tau0", pm.math.exp(log_tau0))
        cos_coeff = pm.Normal("tau_cos1", mu=0.0, sd=10.0)
        sin_coeff = pm.Normal("tau_sin1", mu=0.0, sd=10.0)
        harmonic = HarmonicModelCosineSine(1., cos_coeff, sin_coeff)
        lifetime = LifetimeModel(harmonic, lower=0)

        proxy_model = ProxyModel(
            xs,
            proxy_values,
            amp=amp,
            lag=lag,
            tau0=tau0,
            tau_harm=lifetime,
            tau_scan=10,
        )
        predicted = proxy_model.get_value(xs)

        # Include "jitter": the observation noise is exp(log_jitter)
        log_jitter = pm.Normal("log_jitter", mu=0.0, sd=4.0)
        pm.Normal(
            "obs",
            mu=predicted,
            sd=pm.math.exp(log_jitter),
            observed=noisy_obs,
        )

        # Start the sampler from the maximum a posteriori point.
        map_estimate = pm.find_MAP()
        trace = pm.sample(
            chains=2,
            draws=1000,
            tune=1000,
            init="jitter+adapt_full",
            random_seed=[286923464, 464329682],
            return_inferencedata=True,
            start=map_estimate,
            target_accept=0.9,
        )

    # Collapse chains and draws into one median per variable.
    posterior_medians = trace.posterior.median(dim=("chain", "draw"))
    checked_names = [
        "proxy_amp", "proxy_lag", "proxy_tau0",
        "proxy_tau_cos1", "proxy_tau_sin1",
        "proxy_log_jitter",
    ]
    expected_medians = (3., 2., 1., c, s, np.log(0.5))
    np.testing.assert_allclose(
        posterior_medians[checked_names].to_array(),
        expected_medians,
        atol=3e-2, rtol=1e-2,
    )
||
173 |