Conditions | 28 |
Total Lines | 248 |
Code Lines | 110 |
Lines | 0 |
Ratio | 0 % |
Changes | 0 |
Small methods make your code easier to understand, particularly when combined with a good name. Besides, if your method is small, finding a good name is usually much easier.
For example, if you find yourself adding comments to a method's body, this is usually a good sign that the commented part should be extracted into a new method, with the comment serving as a starting point for the new method's name.
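As an illustration of extracting a method, here is a minimal sketch. It assumes a routine that fills a distance matrix and would otherwise repeat the same double loop once with progress reporting and once without. The names `fill_distance_matrix`, `compare_all` and `on_step` are hypothetical and are not part of the analysed package.

```python
from typing import Callable, List, Sequence

Vector = Sequence[float]


def chebyshev(u: Vector, v: Vector) -> float:
    """L-infinity distance between two equal-length vectors."""
    return max(abs(a - b) for a, b in zip(u, v))


def fill_distance_matrix(
    rows: Sequence[Vector],
    cols: Sequence[Vector],
    distance: Callable[[Vector, Vector], float],
    on_step: Callable[[], None] = lambda: None,
) -> List[List[float]]:
    """Extracted method: the double loop lives in exactly one place.

    ``on_step`` is called once per computed entry, so a caller can hook
    in progress reporting without duplicating the loop.
    """
    dm = [[0.0] * len(cols) for _ in rows]
    for i, row in enumerate(rows):
        for j, col in enumerate(cols):
            dm[i][j] = distance(row, col)
            on_step()
    return dm


def compare_all(
    rows: Sequence[Vector], cols: Sequence[Vector], verbose: bool = False
) -> List[List[float]]:
    """Caller after the refactoring: neither branch contains a loop."""
    if verbose:
        done = []  # hypothetical stand-in for a progress bar
        dm = fill_distance_matrix(
            rows, cols, chebyshev, on_step=lambda: done.append(1)
        )
        print(f"computed {len(done)} distances")
        return dm
    return fill_distance_matrix(rows, cols, chebyshev)
```

The docstring of the extracted function plays the role the inline comment used to play, and its name can be derived directly from that comment.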
Commonly applied refactorings include:
If many parameters/temporary variables are present:
Complex classes like amd.compare.compare() often do a lot of different things. To break such a class down, we need to identify a cohesive component within that class. A common approach to finding such a component is to look for fields/methods that share the same prefixes or suffixes.
Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a sub-class, Extract Subclass is also a candidate, and is often faster.
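A minimal sketch of Extract Class, assuming a class whose fields shared a `progress_` prefix (for example `progress_total` and `progress_done`). Both class names below are hypothetical and only illustrate the shape of the refactoring.

```python
from dataclasses import dataclass, field
from typing import Iterable, List, Sequence, Tuple

Vector = Sequence[float]


@dataclass
class ProgressTracker:
    """Extracted class: holds what used to be the ``progress_*`` fields."""

    total: int
    done: int = 0

    def step(self) -> None:
        self.done += 1

    def fraction(self) -> float:
        # An empty job counts as finished.
        return self.done / self.total if self.total else 1.0


@dataclass
class Comparison:
    """What remains after the extraction: only the comparison state."""

    results: List[float] = field(default_factory=list)

    def run(
        self, pairs: Iterable[Tuple[Vector, Vector]], progress: ProgressTracker
    ) -> None:
        for u, v in pairs:
            # Chebyshev (L-infinity) distance between the two vectors.
            self.results.append(max(abs(a - b) for a, b in zip(u, v)))
            progress.step()
```

After the split, each class has one reason to change, and the tracker can be reused by other long-running loops.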
Methods with many parameters are not only hard to understand, but their parameters also often become inconsistent when you need more or different data.
There are several approaches to avoid long parameter lists:
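One such approach, Introduce Parameter Object, can be sketched as follows. The sketch assumes that the `backend`, `n_jobs` and `verbose` arguments of a comparison function always travel together. `ParallelOptions` and `describe_run` are hypothetical names, not part of the analysed package.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass(frozen=True)
class ParallelOptions:
    """Groups the arguments that control how the work is executed."""

    backend: str = "multiprocessing"
    n_jobs: Optional[int] = None
    verbose: bool = False

    @property
    def is_parallel(self) -> bool:
        # Same rule as the check in PDD_cdist: None, 0 and 1 mean serial.
        return self.n_jobs is not None and self.n_jobs not in (0, 1)


def describe_run(
    n_items: int, options: ParallelOptions = ParallelOptions()
) -> str:
    """Hypothetical consumer: takes one options object, not three flags."""
    if options.is_parallel:
        mode = f"parallel ({options.backend}, n_jobs={options.n_jobs})"
    else:
        mode = "serial"
    return f"{n_items} comparisons, {mode}"
```

Because the dataclass is frozen, one instance can safely serve as a default argument and be shared between calls.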
"""Functions for comparing AMDs and PDDs of crystals."""

# ... (lines 2-39 of the file are not shown)

    return_transport: bool, default False
        Instead return a tuple ``(emd, transport_plan)`` where
        transport_plan describes the optimal flow.

    Returns
    -------
    emd : float
        Earth mover's distance between two PDDs. If ``return_transport``
        is True, return a tuple (emd, transport_plan).

    Raises
    ------
    ValueError
        Thrown if ``pdd`` and ``pdd_`` do not have the same number of
        columns.
    """

    emd_dist, transport_plan = _EMD(
        pdd[:, 0], pdd_[:, 0], pdd[:, 1:], pdd_[:, 1:], metric=metric, **kwargs
    )
    if return_transport:
        return emd_dist, transport_plan
    return emd_dist


def _EMD(
    weights: FloatArray,
    weights_: FloatArray,
    dist: FloatArray,
    dist_: FloatArray,
    metric: str = "chebyshev",
    **kwargs,
) -> Tuple[float, FloatArray]:
    r"""Calculate the earth mover's distance (EMD) between two weighted
    distributions (collections of vectors).

    Parameters
    ----------
    weights : :class:`numpy.ndarray`
        Weights of items in ``dist``.
    weights\_ : :class:`numpy.ndarray`
        Weights of items in ``dist\_``.
    dist : :class:`numpy.ndarray`
        ``(n, d)`` array of items in the first distribution.
    dist\_ : :class:`numpy.ndarray`
        ``(m, d)`` array of items in the second distribution.
    metric : str or callable, default 'chebyshev'
        Metric used as the base distance between items in ``dist`` and
        ``dist\_``. For a list of accepted metrics see
        :func:`scipy.spatial.distance.cdist`.
    **kwargs :
        Extra arguments for ``metric``, passed to
        :func:`scipy.spatial.distance.cdist`.

    Returns
    -------
    emd : float
        Earth mover's distance between the two distributions.
    transport_plan : :class:`numpy.ndarray`
        Matrix of optimal flows between the two distributions.
    """

    # Pairwise base distances between the items of the two distributions.
    dm = cdist(dist, dist_, metric=metric, **kwargs)
    return network_simplex(weights, weights_, dm)


def AMD_cdist(
    amds, amds_, metric: str = "chebyshev", low_memory: bool = False, **kwargs
) -> FloatArray:
    r"""Compare two sets of AMDs with each other, returning a distance
    matrix. This function is essentially
    :func:`scipy.spatial.distance.cdist` with the default metric
    ``chebyshev`` and a low memory option.

    Parameters
    ----------
    amds : ArrayLike
        A list/array of AMDs.
    amds\_ : ArrayLike
        A list/array of AMDs.
    metric : str or callable, default 'chebyshev'
        Usually AMDs are compared with the Chebyshev (L-infinity)
        distance. Accepts any metric accepted by
        :func:`scipy.spatial.distance.cdist`.
    low_memory : bool, default False
        Use a slower but more memory efficient method for large
        collections of AMDs (metric 'chebyshev' only).
    **kwargs :
        Extra arguments for ``metric``, passed to
        :func:`scipy.spatial.distance.cdist`.

    Returns
    -------
    dm : :class:`numpy.ndarray`
        A distance matrix of shape ``(len(amds), len(amds_))``.
        ``dm[i, j]`` is the distance (given by ``metric``) between
        ``amds[i]`` and ``amds_[j]``.
    """

    # Convert both inputs once so the loop below works on arrays.
    amds = np.asarray(amds)
    amds_ = np.asarray(amds_)

    if low_memory:
        if metric != "chebyshev":
            raise ValueError(
                "'low_memory' parameter of amd.AMD_cdist() only implemented "
                "with metric='chebyshev'"
            )
        dm = np.empty((len(amds), len(amds_)))
        # One row at a time: Chebyshev distance from amds[i] to every
        # vector in amds_.
        for i, amd_vec in enumerate(amds):
            dm[i] = np.amax(np.abs(amds_ - amd_vec), axis=-1)
    else:
        dm = cdist(amds, amds_, metric=metric, **kwargs)
    return dm


def AMD_pdist(
    amds, metric: str = "chebyshev", low_memory: bool = False, **kwargs
) -> FloatArray:
    """Compare a set of AMDs pairwise, returning a condensed distance
    matrix. This function is essentially
    :func:`scipy.spatial.distance.pdist` with the default metric
    ``chebyshev`` and a low memory parameter.

    Parameters
    ----------
    amds : ArrayLike
        A list/array of AMDs.
    metric : str or callable, default 'chebyshev'
        Usually AMDs are compared with the Chebyshev (L-infinity)
        distance. Accepts any metric accepted by
        :func:`scipy.spatial.distance.pdist`.
    low_memory : bool, default False
        Use a slower but more memory efficient method for large
        collections of AMDs (metric 'chebyshev' only).
    **kwargs :
        Extra arguments for ``metric``, passed to
        :func:`scipy.spatial.distance.pdist`.

    Returns
    -------
    cdm : :class:`numpy.ndarray`
        Returns a condensed distance matrix. Collapses a square distance
        matrix into a vector, just keeping the upper half. See the
        function :func:`squareform <scipy.spatial.distance.squareform>`
        from SciPy to convert to a symmetric square distance matrix.
    """

    amds = np.asarray(amds)

    @numba.njit(cache=True, fastmath=True)
    def _pdist_lowmem(amds):
        m = amds.shape[0]
        cdm = np.empty((m * (m - 1)) // 2, dtype=np.float64)
        ind = 0
        for i in range(m):
            for j in range(i + 1, m):
                cdm[ind] = np.amax(np.abs(amds[i] - amds[j]))
                # Advance to the next slot of the condensed matrix.
                ind += 1
        return cdm

    if low_memory:
        if metric != "chebyshev":
            raise ValueError(
                "'low_memory' parameter of amd.AMD_pdist() only implemented "
                "with metric='chebyshev'"
            )
        cdm = _pdist_lowmem(amds)
    else:
        cdm = pdist(amds, metric=metric, **kwargs)

    return cdm


def PDD_cdist(
    pdds: List[FloatArray],
    pdds_: List[FloatArray],
    metric: str = "chebyshev",
    backend: str = "multiprocessing",
    n_jobs: Optional[int] = None,
    verbose: bool = False,
    **kwargs,
) -> FloatArray:
    r"""Compare two sets of PDDs with each other, returning a distance
    matrix. Supports parallel processing via joblib. If using
    parallelisation, make sure to include an if __name__ == '__main__'
    guard around this function.

    Parameters
    ----------
    pdds : List[:class:`numpy.ndarray`]
        A list of PDDs.
    pdds\_ : List[:class:`numpy.ndarray`]
        A list of PDDs.
    metric : str or callable, default 'chebyshev'
        Usually PDD rows are compared with the Chebyshev/l-infinity
        distance. Accepts any metric accepted by
        :func:`scipy.spatial.distance.cdist`.
    backend : str, default 'multiprocessing'
        The parallelization backend implementation. For a list of
        supported backends, see the backend argument of
        :class:`joblib.Parallel`.
    n_jobs : int, default None
        Maximum number of concurrent jobs for parallel processing with
        ``joblib``. Set to -1 to use the maximum. Using parallel
        processing may be slower for small inputs.
    verbose : bool, default False
        Prints a progress bar. If using parallel processing
        (n_jobs > 1), the verbose argument of :class:`joblib.Parallel`
        is used, otherwise uses tqdm.
    **kwargs :
        Extra arguments for ``metric``, passed to
        :func:`scipy.spatial.distance.cdist`.

    Returns
    -------
    dm : :class:`numpy.ndarray`
        Returns a distance matrix shape ``(len(pdds), len(pdds_))``. The
        :math:`ij` th entry is the distance between ``pdds[i]`` and
        ``pdds_[j]`` given by Earth mover's distance.
    """

    kwargs.pop("return_transport", None)
    k = pdds[0].shape[-1] - 1
    _verbose = 3 if verbose else 0

    if n_jobs is not None and n_jobs not in (0, 1):
        # TODO: put results into preallocated empty array in place
        dm = Parallel(backend=backend, n_jobs=n_jobs, verbose=_verbose)(
            delayed(partial(EMD, metric=metric, **kwargs))(pdds[i], pdds_[j])
            for i in range(len(pdds))
            for j in range(len(pdds_))
        )
        dm = np.array(dm).reshape((len(pdds), len(pdds_)))

    else:
        n, m = len(pdds), len(pdds_)
        dm = np.empty((n, m))
        if verbose:
            desc = f"Comparing {len(pdds)}x{len(pdds_)} PDDs (k={k})"
            progress_bar = tqdm.tqdm(desc=desc, total=n * m)
            for i in range(n):
                for j in range(m):
                    dm[i, j] = EMD(pdds[i], pdds_[j], metric=metric, **kwargs)
                    progress_bar.update(1)
            progress_bar.close()
        else:
            for i in range(n):
                for j in range(m):
                    dm[i, j] = EMD(pdds[i], pdds_[j], metric=metric, **kwargs)

    return dm

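The condensed distance matrix described in the AMD_pdist docstring can be illustrated with a dependency-free sketch. It is not the library's implementation: the helper names `pdist_condensed`, `condensed_index` and `to_square` are hypothetical, and in practice `scipy.spatial.distance.squareform` performs the conversion to a square matrix.

```python
from typing import List, Sequence

Vector = Sequence[float]


def chebyshev(u: Vector, v: Vector) -> float:
    """L-infinity distance between two equal-length vectors."""
    return max(abs(a - b) for a, b in zip(u, v))


def pdist_condensed(vectors: Sequence[Vector]) -> List[float]:
    """Pairwise distances for i < j, stored row by row in a flat list."""
    m = len(vectors)
    cdm = [0.0] * (m * (m - 1) // 2)
    ind = 0
    for i in range(m):
        for j in range(i + 1, m):
            cdm[ind] = chebyshev(vectors[i], vectors[j])
            ind += 1  # move to the next slot of the condensed vector
    return cdm


def condensed_index(i: int, j: int, m: int) -> int:
    """Position of the pair (i, j), with i < j, in the condensed vector."""
    return m * i - i * (i + 1) // 2 + (j - i - 1)


def to_square(cdm: Sequence[float], m: int) -> List[List[float]]:
    """Expand a condensed vector into a symmetric m x m matrix."""
    square = [[0.0] * m for _ in range(m)]
    for i in range(m):
        for j in range(i + 1, m):
            square[i][j] = square[j][i] = cdm[condensed_index(i, j, m)]
    return square
```

For `m` vectors the condensed form stores `m * (m - 1) // 2` values instead of `m * m`, which is why it is the natural output of a pairwise comparison.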