Test Setup Failed
Push to master (ad57f0...44572e) by Daniel, created 07:12

amd.compare.compare()   F

Complexity

Conditions 28

Size

Total Lines 248
Code Lines 110

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 110
dl 0
loc 248
rs 0
c 0
b 0
f 0
cc 28
nop 8

How to fix

Long Method

Small methods make your code easier to understand, in particular if combined with a good name. Besides, if your method is small, finding a good name is usually much easier.

For example, if you find yourself adding comments to a method's body, this is usually a good sign to extract the commented part to a new method, and use the comment as a starting point when coming up with a good name for this new method.

A commonly applied refactoring is Extract Method, as described above.

Complexity

Complex code like amd.compare.compare() often does a lot of different things. To break it down, we need to identify a cohesive component within it. A common approach to finding such a component is to look for fields/methods that share the same prefixes or suffixes.

Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a sub-class, Extract Subclass is also a candidate, and is often faster.

Many Parameters

Methods with many parameters are not only hard to understand, but their parameters also often become inconsistent when you need more or different data.

There are several approaches to avoiding long parameter lists.
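One such approach is Introduce Parameter Object: group parameters that always travel together into a single value. In this file, `backend`, `n_jobs` and `verbose` appear together in both `PDD_cdist` and `PDD_pdist`, and are forwarded together to `joblib.Parallel`. The sketch below uses a hypothetical frozen dataclass `ParallelOptions`; it reproduces the `3 if verbose else 0` verbosity mapping used in the source, but is not part of the library's API.

```python
from dataclasses import dataclass
from typing import Any, Dict, Optional


@dataclass(frozen=True)
class ParallelOptions:
    """Hypothetical parameter object for the joblib-related arguments."""

    backend: str = "multiprocessing"
    n_jobs: Optional[int] = None
    verbose: bool = False

    def joblib_kwargs(self) -> Dict[str, Any]:
        """Keyword arguments for joblib.Parallel (verbosity 3 if verbose, else 0)."""
        return {
            "backend": self.backend,
            "n_jobs": self.n_jobs,
            "verbose": 3 if self.verbose else 0,
        }


# A signature using it drops from six named parameters to four, e.g.
#   def PDD_cdist(pdds, pdds_, metric="chebyshev",
#                 parallel=ParallelOptions(), **kwargs): ...
# and inside the function: Parallel(**parallel.joblib_kwargs())(...)
```

Because the dataclass is frozen, instances are immutable and hashable, so a single instance is safe to use as a default argument.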

"""Functions for comparing AMDs and PDDs of crystals."""

from typing import List, Optional, Union, Tuple
from functools import partial
from itertools import combinations

import numpy as np
import numba
from scipy.spatial.distance import cdist, pdist
from joblib import Parallel, delayed
import tqdm

from ._emd import network_simplex
from ._types import FloatArray

__all__ = ["EMD", "AMD_cdist", "AMD_pdist", "PDD_cdist", "PDD_pdist"]


def EMD(
    pdd: FloatArray,
    pdd_: FloatArray,
    metric: Optional[str] = "chebyshev",
    return_transport: Optional[bool] = False,
    **kwargs,
) -> Union[float, Tuple[float, FloatArray]]:
    r"""Calculate the Earth mover's distance (EMD) between two PDDs, also
    known as the Wasserstein metric.

    Parameters
    ----------
    pdd : :class:`numpy.ndarray`
        PDD of a crystal.
    pdd\_ : :class:`numpy.ndarray`
        PDD of a crystal.
    metric : str or callable, default 'chebyshev'
        EMD between PDDs requires defining a distance between PDD rows.
        By default, Chebyshev (L-infinity) distance is chosen as with
        AMDs. Accepts any metric accepted by
        :func:`scipy.spatial.distance.cdist`.
    return_transport : bool, default False
        If True, return a tuple ``(emd, transport_plan)`` instead, where
        ``transport_plan`` describes the optimal flow.

    Returns
    -------
    emd : float
        Earth mover's distance between two PDDs. If ``return_transport``
        is True, return a tuple ``(emd, transport_plan)``.

    Raises
    ------
    ValueError
        Raised if ``pdd`` and ``pdd_`` do not have the same number of
        columns.
    """

    # First column holds the weights, remaining columns the distances
    emd_dist, transport_plan = _EMD(
        pdd[:, 0], pdd_[:, 0], pdd[:, 1:], pdd_[:, 1:], metric=metric, **kwargs
    )
    if return_transport:
        return emd_dist, transport_plan
    return emd_dist


def _EMD(
    weights: FloatArray,
    weights_: FloatArray,
    dist: FloatArray,
    dist_: FloatArray,
    metric: Optional[str] = "chebyshev",
    **kwargs,
) -> Tuple[float, FloatArray]:
    r"""Calculate the earth mover's distance (EMD) between two weighted
    distributions (collections of vectors).

    Parameters
    ----------
    dist : :class:`numpy.ndarray`
        ``(n, d)`` array of items in the first distribution.
    dist\_ : :class:`numpy.ndarray`
        ``(m, d)`` array of items in the second distribution.
    weights : :class:`numpy.ndarray`
        Weights of items in ``dist``.
    weights\_ : :class:`numpy.ndarray`
        Weights of items in ``dist_``.
    metric : str or callable, default 'chebyshev'
        Metric used as the base distance between items in ``dist`` and
        ``dist_``. For a list of accepted metrics see
        :func:`scipy.spatial.distance.cdist`.

    Returns
    -------
    emd : float
        Earth mover's distance between the two weighted distributions.
    transport_plan : :class:`numpy.ndarray`
        Matrix of optimal flows between the two distributions.
    """

    # (n, m) matrix of base distances between items of the two distributions
    dm = cdist(dist, dist_, metric=metric, **kwargs)
    return network_simplex(weights, weights_, dm)


def AMD_cdist(
    amds, amds_, metric: str = "chebyshev", low_memory: bool = False, **kwargs
) -> FloatArray:
    r"""Compare two sets of AMDs with each other, returning a distance
    matrix. This function is essentially
    :func:`scipy.spatial.distance.cdist` with the default metric
    ``chebyshev`` and a low memory option.

    Parameters
    ----------
    amds : ArrayLike
        A list/array of AMDs.
    amds\_ : ArrayLike
        A list/array of AMDs.
    metric : str or callable, default 'chebyshev'
        Usually AMDs are compared with the Chebyshev (L-infinity)
        distance. Accepts any metric accepted by
        :func:`scipy.spatial.distance.cdist`.
    low_memory : bool, default False
        Use a slower but more memory efficient method for large
        collections of AMDs (metric 'chebyshev' only).
    **kwargs :
        Extra arguments for ``metric``, passed to
        :func:`scipy.spatial.distance.cdist`.

    Returns
    -------
    dm : :class:`numpy.ndarray`
        A distance matrix of shape ``(len(amds), len(amds_))``.
        ``dm[i, j]`` is the distance (given by ``metric``) between
        ``amds[i]`` and ``amds_[j]``.
    """

    amds = np.asarray(amds)
    amds_ = np.asarray(amds_)

    if low_memory:
        if metric != "chebyshev":
            raise ValueError(
                "'low_memory' parameter of amd.AMD_cdist() only implemented "
                "with metric='chebyshev'"
            )
        dm = np.empty((len(amds), len(amds_)))
        for i, amd_vec in enumerate(amds):
            # Chebyshev distance from one AMD to every AMD in amds_
            dm[i] = np.amax(np.abs(amds_ - amd_vec), axis=-1)
    else:
        dm = cdist(amds, amds_, metric=metric, **kwargs)
    return dm


def AMD_pdist(
    amds, metric: str = "chebyshev", low_memory: bool = False, **kwargs
) -> FloatArray:
    """Compare a set of AMDs pairwise, returning a condensed distance
    matrix. This function is essentially
    :func:`scipy.spatial.distance.pdist` with the default metric
    ``chebyshev`` and a low memory parameter.

    Parameters
    ----------
    amds : ArrayLike
        A list/array of AMDs.
    metric : str or callable, default 'chebyshev'
        Usually AMDs are compared with the Chebyshev (L-infinity)
        distance. Accepts any metric accepted by
        :func:`scipy.spatial.distance.pdist`.
    low_memory : bool, default False
        Use a slower but more memory efficient method for large
        collections of AMDs (metric 'chebyshev' only).
    **kwargs :
        Extra arguments for ``metric``, passed to
        :func:`scipy.spatial.distance.pdist`.

    Returns
    -------
    cdm : :class:`numpy.ndarray`
        Returns a condensed distance matrix. Collapses a square distance
        matrix into a vector, just keeping the upper half. See the
        function :func:`squareform <scipy.spatial.distance.squareform>`
        from SciPy to convert to a symmetric square distance matrix.
    """

    amds = np.asarray(amds)

    @numba.njit(cache=True, fastmath=True)
    def _pdist_lowmem(amds):
        m = amds.shape[0]
        cdm = np.empty((m * (m - 1)) // 2, dtype=np.float64)
        ind = 0  # position of the pair (i, j) in the condensed matrix
        for i in range(m):
            for j in range(i + 1, m):
                cdm[ind] = np.amax(np.abs(amds[i] - amds[j]))
                ind += 1
        return cdm

    if low_memory:
        if metric != "chebyshev":
            raise ValueError(
                "'low_memory' parameter of amd.AMD_pdist() only implemented "
                "with metric='chebyshev'"
            )
        cdm = _pdist_lowmem(amds)
    else:
        cdm = pdist(amds, metric=metric, **kwargs)

    return cdm


def PDD_cdist(
    pdds: List[FloatArray],
    pdds_: List[FloatArray],
    metric: str = "chebyshev",
    backend: str = "multiprocessing",
    n_jobs: Optional[int] = None,
    verbose: bool = False,
    **kwargs,
) -> FloatArray:
    r"""Compare two sets of PDDs with each other, returning a distance
    matrix. Supports parallel processing via joblib. If using
    parallelisation, make sure to place calls to this function inside an
    ``if __name__ == '__main__'`` guard.

    Parameters
    ----------
    pdds : List[:class:`numpy.ndarray`]
        A list of PDDs.
    pdds\_ : List[:class:`numpy.ndarray`]
        A list of PDDs.
    metric : str or callable, default 'chebyshev'
        Usually PDD rows are compared with the Chebyshev (L-infinity)
        distance. Accepts any metric accepted by
        :func:`scipy.spatial.distance.cdist`.
    backend : str, default 'multiprocessing'
        The parallelization backend implementation. For a list of
        supported backends, see the backend argument of
        :class:`joblib.Parallel`.
    n_jobs : int, default None
        Maximum number of concurrent jobs for parallel processing with
        ``joblib``. Set to -1 to use the maximum. Using parallel
        processing may be slower for small inputs.
    verbose : bool, default False
        Prints a progress bar. If using parallel processing
        (``n_jobs`` other than None, 0 or 1), the verbose argument of
        :class:`joblib.Parallel` is used, otherwise uses tqdm.
    **kwargs :
        Extra arguments for ``metric``, passed to
        :func:`scipy.spatial.distance.cdist`.

    Returns
    -------
    dm : :class:`numpy.ndarray`
        Returns a distance matrix of shape ``(len(pdds), len(pdds_))``.
        ``dm[i, j]`` is the Earth mover's distance between ``pdds[i]``
        and ``pdds_[j]``.
    """

    kwargs.pop("return_transport", None)
    k = pdds[0].shape[-1] - 1
    _verbose = 3 if verbose else 0

    if n_jobs is not None and n_jobs not in (0, 1):
        # TODO: put results into preallocated empty array in place
        dm = Parallel(backend=backend, n_jobs=n_jobs, verbose=_verbose)(
            delayed(partial(EMD, metric=metric, **kwargs))(pdds[i], pdds_[j])
            for i in range(len(pdds))
            for j in range(len(pdds_))
        )
        dm = np.array(dm).reshape((len(pdds), len(pdds_)))

    else:
        n, m = len(pdds), len(pdds_)
        dm = np.empty((n, m))
        if verbose:
            desc = f"Comparing {len(pdds)}x{len(pdds_)} PDDs (k={k})"
            progress_bar = tqdm.tqdm(desc=desc, total=n * m)
            for i in range(n):
                for j in range(m):
                    dm[i, j] = EMD(pdds[i], pdds_[j], metric=metric, **kwargs)
                    progress_bar.update(1)
            progress_bar.close()
        else:
            for i in range(n):
                for j in range(m):
                    dm[i, j] = EMD(pdds[i], pdds_[j], metric=metric, **kwargs)

    return dm


def PDD_pdist(
    pdds: List[FloatArray],
    metric: str = "chebyshev",
    backend: str = "multiprocessing",
    n_jobs: Optional[int] = None,
    verbose: bool = False,
    **kwargs,
) -> FloatArray:
    """Compare a set of PDDs pairwise, returning a condensed distance
    matrix. Supports parallelisation via joblib. If using
    parallelisation, make sure to place calls to this function inside an
    ``if __name__ == '__main__'`` guard.

    Parameters
    ----------
    pdds : List[:class:`numpy.ndarray`]
        A list of PDDs.
    metric : str or callable, default 'chebyshev'
        Usually PDD rows are compared with the Chebyshev (L-infinity)
        distance. Accepts any metric accepted by
        :func:`scipy.spatial.distance.cdist`.
    backend : str, default 'multiprocessing'
        The parallelization backend implementation. For a list of
        supported backends, see the backend argument of
        :class:`joblib.Parallel`.
    n_jobs : int, default None
        Maximum number of concurrent jobs for parallel processing with
        ``joblib``. Set to -1 to use the maximum. Using parallel
        processing may be slower for small inputs.
    verbose : bool, default False
        Prints a progress bar. If using parallel processing
        (``n_jobs`` other than None, 0 or 1), the verbose argument of
        :class:`joblib.Parallel` is used, otherwise uses tqdm.
    **kwargs :
        Extra arguments for ``metric``, passed to
        :func:`scipy.spatial.distance.cdist`.

    Returns
    -------
    cdm : :class:`numpy.ndarray`
        Returns a condensed distance matrix. Collapses a square distance
        matrix into a vector, just keeping the upper half. See the
        function :func:`squareform <scipy.spatial.distance.squareform>`
        from SciPy to convert to a symmetric square distance matrix.
    """

    kwargs.pop("return_transport", None)
    k = pdds[0].shape[-1] - 1
    _verbose = 3 if verbose else 0

    # n_jobs=-1 ("use the maximum") must also take the parallel path
    if n_jobs is not None and n_jobs not in (0, 1):
        # TODO: put results into preallocated empty array in place
        cdm = Parallel(backend=backend, n_jobs=n_jobs, verbose=_verbose)(
            delayed(partial(EMD, metric=metric, **kwargs))(pdds[i], pdds[j])
            for i, j in combinations(range(len(pdds)), 2)
        )
        cdm = np.array(cdm)

    else:
        m = len(pdds)
        cdm_len = (m * (m - 1)) // 2
        cdm = np.empty(cdm_len, dtype=np.float64)
        inds = ((i, j) for i in range(0, m - 1) for j in range(i + 1, m))
        if verbose:
            desc = f"Comparing {len(pdds)} PDDs pairwise (k={k})"
            progress_bar = tqdm.tqdm(desc=desc, total=cdm_len)
            for r, (i, j) in enumerate(inds):
                cdm[r] = EMD(pdds[i], pdds[j], metric=metric, **kwargs)
                progress_bar.update(1)
            progress_bar.close()
        else:
            for r, (i, j) in enumerate(inds):
                cdm[r] = EMD(pdds[i], pdds[j], metric=metric, **kwargs)

    return cdm
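`AMD_pdist` and `PDD_pdist` both return a condensed distance matrix: the upper triangle of the square matrix, stored row by row as a vector. The sketch below uses plain Python rather than SciPy, with hypothetical helper names `condensed_index` and `to_square`. It shows how pair `(i, j)` with `i < j` maps to a position in that vector, and how to expand the vector back into a full symmetric matrix, which is what `scipy.spatial.distance.squareform` does for NumPy arrays. Keeping this position correct is the job of the running index in `_pdist_lowmem`, which must advance by one after every pair. `itertools.combinations(range(n), 2)` yields pairs in exactly this order.

```python
from itertools import combinations
from typing import List, Sequence


def condensed_index(n: int, i: int, j: int) -> int:
    """Position of pair (i, j) in a condensed matrix of n items (hypothetical helper)."""
    if i == j:
        raise ValueError("diagonal entries are not stored in a condensed matrix")
    if i > j:
        i, j = j, i  # the matrix is symmetric, so the order of i and j is irrelevant
    # Rows 0..i-1 hold (n-1) + (n-2) + ... + (n-i) = n*i - i*(i+1)/2 entries,
    # and (i, j) is entry (j - i - 1) within row i.
    return n * i - i * (i + 1) // 2 + (j - i - 1)


def to_square(cdm: Sequence[float], n: int) -> List[List[float]]:
    """Expand a condensed matrix into a full symmetric one with a zero diagonal."""
    if len(cdm) != n * (n - 1) // 2:
        raise ValueError("length of cdm does not match n")
    square = [[0.0] * n for _ in range(n)]
    for i, j in combinations(range(n), 2):
        square[i][j] = square[j][i] = cdm[condensed_index(n, i, j)]
    return square
```

For three items, `to_square([1.0, 2.0, 3.0], 3)` places 1.0 at (0, 1), 2.0 at (0, 2) and 3.0 at (1, 2), mirrored below the diagonal.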