| Metric | Value |
| --- | --- |
| Conditions | 15 |
| Total Lines | 157 |
| Code Lines | 74 |
| Lines | 157 |
| Ratio | 100 % |
| Tests | 1 |
| CRAP Score | 229.7642 |
| Changes | 0 |
Small methods make your code easier to understand, in particular when combined with a good name. Besides, if your method is small, finding a good name is usually much easier.
For example, if you find yourself adding comments to a method's body, this is usually a good sign to extract the commented part into a new method, and to use the comment as a starting point when coming up with a good name for this new method.
Commonly applied refactorings include:

- Extract Method (sketched below)

If many parameters/temporary variables are present:

- Replace Method with Method Object
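As a minimal sketch of Extract Method (the names here are hypothetical, not taken from sciapy), the comment above a block becomes the name of a new method:

```python
# before: a comment labels what the block does
def summarize(values):
    # normalize to weights that sum to one
    total = sum(values)
    weights = [v / total for v in values]
    return max(weights)


# after: the commented block is its own, well-named method
def normalized_weights(values):
    total = sum(values)
    return [v / total for v in values]


def summarize(values):
    return max(normalized_weights(values))
```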
Complex functions like sciapy.regress.statistics.waic_loo() often do a lot of different things. To break such a function down, we need to identify cohesive components within it. A common approach to finding such a component is to look for fields/methods, or parameters/variables, that share the same prefixes or suffixes.
Once you have determined the fields that belong together, you can apply the Extract Class refactoring (sketched below). If the component makes sense as a subclass, Extract Subclass is also a candidate, and is often faster.
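A minimal Extract Class sketch using the same prefix heuristic (all names hypothetical): fields sharing a `dask_` prefix move into a class of their own.

```python
# before: the "dask_" prefix hints at a hidden component
class Sampler:
    def __init__(self, nthreads, dask_cluster, dask_scheduler_address):
        self.nthreads = nthreads
        self.dask_cluster = dask_cluster
        self.dask_scheduler_address = dask_scheduler_address


# after: the prefixed fields form a cohesive, extracted class
class DaskSettings:
    def __init__(self, cluster, scheduler_address):
        self.cluster = cluster
        self.scheduler_address = scheduler_address


class Sampler:
    def __init__(self, nthreads, dask_settings):
        self.nthreads = nthreads
        self.dask_settings = dask_settings
```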
Methods with many parameters are not only hard to understand; their parameter lists also tend to become inconsistent when you need more, or different, data.
There are several approaches to avoid long parameter lists:

- Introduce Parameter Object (sketched below)
- Preserve Whole Object
- Replace Parameter with Method Call
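Applied to the waic_loo() signature below, an Introduce Parameter Object sketch might bundle the three scheduling arguments into one object (hypothetical, not part of sciapy's API):

```python
from dataclasses import dataclass
from typing import Optional, Union


@dataclass
class SchedulerOptions:
    """Bundles the point-wise-computation scheduling knobs of waic_loo()."""
    nthreads: int = 1
    use_dask: bool = False
    dask_cluster: Optional[Union[str, object]] = None


def waic_loo(model, times, data, errs, samples,
             method="likelihood", train_data=None, noisy_targets=True,
             scheduler=None):
    # one cohesive object replaces three loosely coupled keyword arguments
    scheduler = scheduler if scheduler is not None else SchedulerOptions()
    ...
```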
```python
# -*- coding: utf-8 -*-
# (file lines 2-51, the module header and imports, are not shown in this
# report; the function below relies on `import logging` and
# `import numpy as np` as well as the module-internal helpers
# `_log_lh_pt` and `_log_pred_pt`)


def waic_loo(model, times, data, errs,
             samples,
             method="likelihood",
             train_data=None,
             noisy_targets=True,
             nthreads=1,
             use_dask=False,
             dask_cluster=None,
             ):
    """Watanabe-Akaike information criterion (WAIC) and LOO IC of the (GP) model

    Calculates the WAIC and leave-one-out (LOO) cross validation scores and
    information criteria (IC) from the MCMC samples of the posterior parameter
    distributions. Uses the posterior point-wise (per data point)
    probabilities and the formulae from [1]_ and [2]_.

    .. [1] Vehtari, Gelman, and Gabry, Stat Comput (2017) 27:1413–1432,
        doi: 10.1007/s11222-016-9696-4

    .. [2] Vehtari and Gelman, (unpublished)
        http://www.stat.columbia.edu/~gelman/research/unpublished/waic_stan.pdf
        http://www.stat.columbia.edu/~gelman/research/unpublished/loo_stan.pdf

    Parameters
    ----------
    model : `celerite.GP`, `george.GP` or `CeleriteModelSet` instance
        The model instance whose parameter distribution was drawn.
    times : (M,) array_like
        The test coordinates to predict or evaluate the model on.
    data : (M,) array_like
        The test data to test the model against.
    errs : (M,) array_like
        The errors (variances) of the test data.
    samples : (K, L) array_like
        The `K` MCMC samples of the `L` parameter distributions.
    method : str ("likelihood" or "predict"), optional
        The method to "predict" the data, the default uses the (log)likelihood
        in the same way as is done when fitting (training) the model.
        "predict" uses the actual GP prediction, which might be useful if the
        IC should be estimated for actual test data that was not used to
        train the model.
    train_data : (N,) array_like, optional
        The data on which the model was trained, needed if method="predict"
        is used, otherwise None is the default and the likelihood is used.
    noisy_targets : bool, optional
        Include the given errors when calculating the predictive probability.
    nthreads : int, optional
        Number of threads to distribute the point-wise probability
        calculations to (default: 1).
    use_dask : bool, optional
        Use `dask.distributed` to distribute the point-wise probability
        calculations to `nthreads` workers. The default is to use
        `multiprocessing.pool.Pool()`.
    dask_cluster : str, or `dask.distributed.Cluster` instance, optional
        Will be passed to `dask.distributed.Client()`.
        This can be the address of a Scheduler server like the string
        '127.0.0.1:8786' or a cluster object like
        `dask.distributed.LocalCluster()`.

    Returns
    -------
    waic, waic_se, p_waic, loo_ic, loo_ic_se, p_loo : tuple
        The WAIC and its standard error as well as the
        estimated effective number of parameters, p_waic.
        The LOO IC, its standard error, and the estimated
        effective number of parameters, p_loo.
    """
    from functools import partial
    from multiprocessing import pool
    from scipy.special import logsumexp
    try:
        from tqdm.autonotebook import tqdm
    except ImportError:
        tqdm = None
    try:
        from dask.distributed import Client, LocalCluster, progress
    except ImportError:
        use_dask = False

    # the predictive covariance should include the data variance;
    # set it to a small value if we don't want to account for it
    if not noisy_targets or errs is None:
        errs = 1.123e-12

    # point-wise posterior/predictive probabilities
    _log_p_pt = partial(_log_lh_pt, model, times, data, errs)
    if method == "predict" and train_data is not None:
        _log_p_pt = partial(_log_pred_pt, model, train_data, times, data, errs)

    # calculate the point-wise probabilities and stack them together
    if nthreads > 1:
        if use_dask:
            if dask_cluster is None:
                # start a local dask cluster
                _cl = LocalCluster(n_workers=nthreads, threads_per_worker=1)
            else:
                # use the provided dask cluster
                _cl = dask_cluster
            _c = Client(_cl)
            _log_pred = _c.map(_log_p_pt, samples)
            progress(_log_pred)
            log_pred = np.stack(_c.gather(_log_pred))
            _c.close()
            if dask_cluster is None:
                _cl.close()
        else:
            # multiprocessing.pool
            _p = pool.Pool(processes=nthreads)
            _mapped = _p.imap_unordered(_log_p_pt, samples)
            if tqdm is not None:
                _mapped = tqdm(_mapped, total=len(samples))
            log_pred = np.stack(list(_mapped))
            _p.close()
            _p.join()
    else:
        if tqdm is not None:
            samples = tqdm(samples, total=len(samples))
        log_pred = np.stack(list(map(_log_p_pt, samples)))

    # log point-wise predictive density and WAIC effective parameters
    lppd_i = logsumexp(log_pred, b=1. / log_pred.shape[0], axis=0)
    p_waic_i = np.nanvar(log_pred, ddof=1, axis=0)
    if np.any(p_waic_i > 0.4):
        logging.warning("""For one or more samples the posterior variance of
            the log predictive densities exceeds 0.4. This could be an
            indication of WAIC starting to fail; see
            http://arxiv.org/abs/1507.04544 for details.""")
    elpd_i = lppd_i - p_waic_i
    waic_i = -2. * elpd_i
    waic_se = np.sqrt(len(waic_i) * np.nanvar(waic_i, ddof=1))
    waic = np.nansum(waic_i)
    p_waic = np.nansum(p_waic_i)
    if 2. * p_waic > len(waic_i):
        logging.warning("""p_waic > n / 2,
            the WAIC approximation is unreliable.""")
    logging.info("WAIC: %s, waic_se: %s, p_w: %s", waic, waic_se, p_waic)

    # LOO via importance weights, truncated at sqrt(K) as in [1]_
    loo_ws = 1. / np.exp(log_pred - np.nanmax(log_pred, axis=0))
    loo_ws_n = loo_ws / np.nanmean(loo_ws, axis=0)
    loo_ws_r = np.clip(loo_ws_n, None, np.sqrt(log_pred.shape[0]))
    elpd_loo_i = logsumexp(log_pred,
                           b=loo_ws_r / np.nansum(loo_ws_r, axis=0),
                           axis=0)
    p_loo_i = lppd_i - elpd_loo_i
    loo_ic_i = -2 * elpd_loo_i
    loo_ic_se = np.sqrt(len(loo_ic_i) * np.nanvar(loo_ic_i))
    loo_ic = np.nansum(loo_ic_i)
    p_loo = np.nansum(p_loo_i)
    logging.info("loo IC: %s, se: %s, p_loo: %s", loo_ic, loo_ic_se, p_loo)

    # van der Linde, Statistica Neerlandica, 2005
    # https://doi.org/10.1111/j.1467-9574.2005.00278.x
    hy1 = -np.nanmean(lppd_i)
    hy2 = -np.nanmedian(lppd_i)
    logging.info("H(Y): mean %s, median: %s", hy1, hy2)

    return waic, waic_se, p_waic, loo_ic, loo_ic_se, p_loo
```
209 | |||
349 |
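To make the WAIC bookkeeping above easier to follow in isolation, here is a self-contained toy run of the same formulae on a synthetic (K, M) matrix of point-wise log-probabilities (illustrative numbers only, not sciapy output):

```python
import numpy as np
from scipy.special import logsumexp

rng = np.random.default_rng(0)
# K=200 posterior samples, M=30 data points of log p(y_i | theta_k)
log_pred = rng.normal(-1.0, 0.2, size=(200, 30))

# log point-wise predictive density: log of the posterior-mean probability
lppd_i = logsumexp(log_pred, b=1. / log_pred.shape[0], axis=0)
# effective number of parameters: posterior variance of the log densities
p_waic_i = np.nanvar(log_pred, ddof=1, axis=0)
waic_i = -2. * (lppd_i - p_waic_i)
waic = np.nansum(waic_i)
waic_se = np.sqrt(len(waic_i) * np.nanvar(waic_i, ddof=1))
print("WAIC: %.2f +/- %.2f" % (waic, waic_se))
```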