Conditions | 3 |
Total Lines | 68 |
Code Lines | 48 |
Lines | 0 |
Ratio | 0 % |
Changes | 0 |
Small methods make your code easier to understand, in particular if combined with a good name. Besides, if your method is small, finding a good name is usually much easier.
For example, if you find yourself adding comments to a method's body, this is usually a good sign to extract the commented part to a new method, and use the comment as a starting point when coming up with a good name for this new method.
Commonly applied refactorings include:
- Extract Method
If many parameters/temporary variables are present:
- Replace Temp with Query
- Introduce Parameter Object
- Preserve Whole Object
1 | # -*- coding: utf-8 -*- |
||
def _proxy_test_times(xx, f):
    """Return the time grid used by ``test_proxy_theano``.

    The model built from this grid uses ``days_per_time_unit=f * 365.25``,
    so ``f == 1`` corresponds to a time unit of years and
    ``f == 1 / 365.25`` to a time unit of days.

    Parameters
    ----------
    xx : array_like
        Sample positions from the ``xx`` fixture (presumably in days --
        TODO confirm against the fixture definition).
    f : float
        Time-unit scale factor, see above.

    Returns
    -------
    array_like
        ``xx`` scaled by ``1 / (f * 365.25)``.  For ``f >= 1`` the result is
        additionally offset so that ``xx == 44.25`` maps to 1859.
    """
    dx = 1. / (f * 365.25)
    if f < 1.:
        return xx * dx
    # convert to fractional years
    return 1859 + (xx - 44.25) * dx


def _proxy_test_model(xs, values, yp, f):
    """Build the PyMC model fitted by ``test_proxy_theano``.

    The model is named ``"proxy"``, so every variable is registered with a
    ``proxy_`` prefix; use ``model.name_for`` to look the names up.

    Parameters
    ----------
    xs : array_like
        Time grid, in units of ``f * 365.25`` days.
    values : array_like
        Proxy values on ``xs``, passed to ``ProxyModel``.
    yp : array_like
        Noisy observations the model is conditioned on (``obs``).
    f : float
        Time-unit scale factor; also the first argument of
        ``HarmonicModelCosineSine``.

    Returns
    -------
    pm.Model
        The model; callers sample from it inside a ``with model:`` block.
    """
    # using "name" prefixes all variables with <name>_
    with pm.Model(name="proxy") as model:
        # amplitude
        pamp = pm.Normal("amp", mu=0.0, sigma=4.0)
        # lag
        plag = pm.Lognormal("lag", mu=0.0, sigma=4.0, testval=1.0)
        # lifetime: constant part (tau0) and a harmonic part (tau_harm)
        ptau0 = pm.Lognormal("tau0", mu=0.0, sigma=4.0, testval=1.0)
        cos1 = pm.Normal("tau_cos1", mu=0.0, sigma=10.0)
        sin1 = pm.Normal("tau_sin1", mu=0.0, sigma=10.0)
        harm1 = HarmonicModelCosineSine(f, cos1, sin1)
        tau1 = LifetimeModel(harm1, lower=0)

        proxy = ProxyModel(
            xs, values,
            amp=pamp,
            lag=plag,
            tau0=ptau0,
            tau_harm=tau1,
            tau_scan=10,
            days_per_time_unit=f * 365.25,
        )
        prox1 = proxy.get_value(xs)
        # Include "jitter": the observation noise scale is fitted in log space
        log_jitter = pm.Normal("log_jitter", mu=0.0, sigma=4.0)
        pm.Normal("obs", mu=prox1, sigma=pm.math.exp(log_jitter), observed=yp)
    return model


@pytest.mark.long
@pytest.mark.parametrize(
    "f",
    [1., 1. / 365.25]
)
def test_proxy_theano(xx, f, c=3.0, s=1.0):
    """Fit the proxy model to synthetic data and check the recovered parameters.

    Runs once with a time unit of years (``f = 1``) and once with a time
    unit of days (``f = 1 / 365.25``).  ``c`` and ``s`` are passed to the
    data generators and are the expected posterior medians of ``tau_cos1``
    and ``tau_sin1``.
    """
    # Initialize random number generator
    np.random.seed(93457)

    xs = _proxy_test_times(xx, f)
    # proxy "values"
    values = _yy(xs, c, s)

    yp = _test_data(xs, values, f, c, s)
    # Gaussian noise with standard deviation 0.5, recovered as ``log_jitter``
    yp += 0.5 * np.random.randn(xs.shape[0])

    model = _proxy_test_model(xs, values, yp, f)
    with model:
        # Exercise the MAP optimizer; its result is deliberately not used to
        # initialize sampling (it was previously bound to an unused local).
        pm.find_MAP()
        trace = pm.sample(
            chains=2,
            draws=400,
            tune=400,
            random_seed=[286923464, 464329682],
            return_inferencedata=True,
        )

    medians = trace.posterior.median(dim=("chain", "draw"))
    var_names = [
        model.name_for(n)
        for n in [
            "amp", "lag", "tau0", "tau_cos1", "tau_sin1", "log_jitter",
        ]
    ]
    # NOTE(review): amp=3, lag=2, tau0=1 presumably match the parameters
    # `_test_data` uses to generate `yp` -- confirm against its definition.
    np.testing.assert_allclose(
        medians[var_names].to_array(),
        (3., 2., 1., c, s, np.log(0.5)),
        atol=3e-2, rtol=1e-2,
    )
185 |