Passed
Push — master ( b2b3ff...40aa73 )
by Daniel
01:45
created

amd.compare.EMD()   A

Complexity

Conditions 2

Size

Total Lines 42
Code Lines 11

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 11
dl 0
loc 42
rs 9.85
c 0
b 0
f 0
cc 2
nop 5
1
"""Functions for comparing AMDs and PDDs of crystals.
2
"""
3
4
from typing import List, Optional, Union
5
import warnings
6
7
import numpy as np
8
import scipy.spatial    # cdist, pdist, squareform
9
import scipy.optimize   # linear_sum_assignment
10
11
from ._network_simplex import network_simplex
12
from .utils import ETA
13
14
15
def EMD(
        pdd: np.ndarray,
        pdd_: np.ndarray,
        metric: Optional[str] = 'chebyshev',
        return_transport: Optional[bool] = False,
        **kwargs):
    r"""Earth mover's distance (EMD) between two PDDs, also known as
    the Wasserstein metric.

    The first column of each PDD is used as the row weights, the remaining
    columns as the points being transported.

    Parameters
    ----------
    pdd : ndarray
        PDD of a crystal.
    pdd\_ : ndarray
        PDD of a crystal.
    metric : str or callable, optional
        EMD between PDDs requires defining a distance between rows of two PDDs.
        By default, Chebyshev/l-infinity distance is chosen as with AMDs.
        Can take any metric + ``kwargs`` accepted by
        ``scipy.spatial.distance.cdist``.
    return_transport: bool, optional
        Return a tuple (distance, transport_plan) with the optimal transport.

    Returns
    -------
    float or tuple of (float, ndarray)
        Earth mover's distance between PDDs. If ``return_transport`` is True,
        a tuple ``(distance, transport_plan)`` is returned instead, where
        ``transport_plan`` has shape ``(len(pdd), len(pdd_))``.

    Raises
    ------
    ValueError
        Thrown if the two PDDs do not have the
        same number of columns (``k`` value).
    """

    # Fail early with a clear message; cdist would otherwise raise a less
    # descriptive ValueError for the same condition.
    if pdd.shape[1] != pdd_.shape[1]:
        raise ValueError(
            'PDDs must have the same number of columns (k value) to be '
            f'compared, got {pdd.shape[1]} and {pdd_.shape[1]}.'
        )

    # Column 0 holds the weights; the rest are the row vectors compared.
    dm = scipy.spatial.distance.cdist(pdd[:, 1:], pdd_[:, 1:], metric=metric, **kwargs)
    emd_dist, transport_plan = network_simplex(pdd[:, 0], pdd_[:, 0], dm)

    if return_transport:
        return emd_dist, transport_plan.reshape(dm.shape)

    return emd_dist
57
58
59
def SDD_EMD(sdd, sdd_, return_transport: Optional[bool] = False):
    r"""Earth mover's distance (EMD) between two SDDs.

    ``sdd[0]`` is used as the row weights of the transport problem and
    ``sdd[1]``, ``sdd[2]`` hold the distance data compared between rows.
    A first order SDD (``sdd[2]`` is 2-dimensional) is compared exactly
    like a PDD with the Chebyshev metric.

    Parameters
    ----------
    sdd : tuple of ndarrays
        SDD of a crystal.
    sdd\_ : tuple of ndarrays
        SDD of a crystal.
    return_transport: bool, optional
        Return a tuple (distance, transport_plan) with the optimal transport.

    Returns
    -------
    float or tuple of (float, ndarray)
        Earth mover's distance between SDDs. If ``return_transport`` is
        True, a tuple ``(distance, transport_plan)`` is returned instead.

    Raises
    ------
    ValueError
        Thrown if the two SDDs are not of the same order or do not have the
        same number of columns (``k`` value).
    """

    weights, weights_ = sdd[0], sdd_[0]
    dists, dists_ = sdd[2], sdd_[2]

    # SDDs of different orders have differently shaped distance arrays;
    # reject them up front instead of failing somewhere in the loops below.
    if dists.ndim != dists_.ndim or (dists.ndim != 2 and dists.shape[-1] != dists_.shape[-1]):
        raise ValueError('SDDs must be of the same order to be compared.')

    if dists.ndim == 2:
        # First order SDD, equivalent to PDD.
        dm = scipy.spatial.distance.cdist(dists, dists_, metric='chebyshev')
    else:
        order = dists.shape[-1]
        n, m = len(weights), len(weights_)

        if order == 2:
            dist_cdist = np.abs(sdd[1][:, None] - sdd_[1])
        else:
            # Take EMDs between the finite point sets in the dist component,
            # each point weighted uniformly.
            dist, dist_ = sdd[1], sdd_[1]
            uniform = np.full((order, ), 1 / order)
            dist_cdist = np.empty((n, m), dtype=np.float64)
            for i in range(n):
                for j in range(m):
                    finite_pdd_dm = scipy.spatial.distance.cdist(
                        dist[i], dist_[j], metric='chebyshev')
                    dist_cdist[i, j], _ = network_simplex(
                        uniform, uniform, finite_pdd_dm)

        dm = np.empty((n, m), dtype=np.float64)
        for i in range(n):
            for j in range(m):
                cost_matrix = scipy.spatial.distance.cdist(
                    dists[i], dists_[j], metric='chebyshev')
                # Assignment minimising total cost; the row distance is the
                # largest matched cost, or the dist comparison if larger.
                row_ind, col_ind = scipy.optimize.linear_sum_assignment(cost_matrix)
                dm[i, j] = max(np.amax(cost_matrix[row_ind, col_ind]), dist_cdist[i, j])

    emd_dist, transport_plan = network_simplex(weights, weights_, dm)

    if return_transport:
        return emd_dist, transport_plan.reshape(dm.shape)

    return emd_dist
133
134
135
def AMD_cdist(
        amds: Union[np.ndarray, List[np.ndarray]],
        amds_: Union[np.ndarray, List[np.ndarray]],
        metric: str = 'chebyshev',
        low_memory: bool = False,
        **kwargs
) -> np.ndarray:
    r"""Compare two sets of AMDs with each other, returning a distance matrix.

    Parameters
    ----------
    amds : array_like
        A list of AMDs.
    amds\_ : array_like
        A list of AMDs.
    metric : str or callable, optional
        Usually AMDs are compared with the Chebyshev/l-infinity distance.
        Can take any metric + kwargs accepted by ``scipy.spatial.distance.cdist``.
    low_memory : bool, optional
        Use a slower but more memory efficient method for
        large collections of AMDs (Chebyshev/l-inf distance only).
        ``metric`` and ``kwargs`` are ignored in this mode.

    Returns
    -------
    ndarray
        A distance matrix shape ``(len(amds), len(amds_))``.
        The :math:`ij` th entry is the distance between ``amds[i]``
        and ``amds_[j]`` given by the ``metric``.
    """

    amds, amds_ = np.asarray(amds), np.asarray(amds_)

    # A single AMD vector is treated as a collection of one.
    if amds.ndim == 1:
        amds = np.array([amds])
    if amds_.ndim == 1:
        amds_ = np.array([amds_])

    if low_memory:
        if metric != 'chebyshev':
            warnings.warn("Using only allowed metric 'chebyshev' for low_memory", UserWarning)

        # One row at a time: peak extra memory is O(len(amds_) * k)
        # rather than cdist's intermediate buffers.
        dm = np.empty((len(amds), len(amds_)))
        for i, amd_vec in enumerate(amds):
            dm[i] = np.amax(np.abs(amds_ - amd_vec), axis=-1)
    else:
        dm = scipy.spatial.distance.cdist(amds, amds_, metric=metric, **kwargs)

    return dm
183
184
185
def AMD_pdist(
        amds: Union[np.ndarray, List[np.ndarray]],
        metric: str = 'chebyshev',
        low_memory: bool = False,
        **kwargs
) -> np.ndarray:
    """Compare a set of AMDs pairwise, returning a condensed distance matrix.

    Parameters
    ----------
    amds : array_like
        An array/list of AMDs.
    metric : str or callable, optional
        Usually AMDs are compared with the Chebyshev/l-infinity distance.
        Can take any metric + kwargs accepted by ``scipy.spatial.distance.pdist``.
    low_memory : bool, optional
        Optionally use a slightly slower but more memory efficient method for
        large collections of AMDs (Chebyshev/l-inf distance only).
        ``metric`` and ``kwargs`` are ignored in this mode.

    Returns
    -------
    ndarray
        Returns a condensed distance matrix. Collapses a square distance
        matrix into a vector just keeping the upper half. Use
        ``scipy.spatial.distance.squareform`` to convert to a square distance matrix.
    """

    amds = np.asarray(amds)

    # A single AMD vector is treated as a collection of one.
    if amds.ndim == 1:
        amds = np.array([amds])

    if low_memory:
        m = len(amds)
        if metric != 'chebyshev':
            warnings.warn("Using only allowed metric 'chebyshev' for low_memory", UserWarning)
        cdm = np.empty((m * (m - 1)) // 2, dtype=np.double)
        ind = 0
        # Row i contributes the distances to rows i+1..m-1; the last row has
        # no successors, so it is skipped.
        for i in range(m - 1):
            ind_ = ind + m - i - 1
            cdm[ind:ind_] = np.amax(np.abs(amds[i+1:] - amds[i]), axis=-1)
            ind = ind_
    else:
        cdm = scipy.spatial.distance.pdist(amds, metric=metric, **kwargs)

    return cdm
231
232
233
def PDD_cdist(
        pdds: List[np.ndarray],
        pdds_: List[np.ndarray],
        metric: str = 'chebyshev',
        verbose=False,
        **kwargs
) -> np.ndarray:
    r"""Compare two sets of PDDs with each other, returning a distance matrix.

    Parameters
    ----------
    pdds : list of ndarrays
        A list of PDDs. A single 2-D PDD array is treated as a list of one.
    pdds\_ : list of ndarrays
        A list of PDDs. A single 2-D PDD array is treated as a list of one.
    metric : str or callable, optional
        Usually PDD rows are compared with the Chebyshev/l-infinity distance.
        Can take any metric + kwargs accepted by
        ``scipy.spatial.distance.cdist``.
    verbose : bool, optional
        If True, report progress through an ``ETA`` tracker.

    Returns
    -------
    ndarray
        Returns a distance matrix shape ``(len(pdds), len(pdds_))``.
        The :math:`ij` th entry is the distance between ``pdds[i]``
        and ``pdds_[j]`` given by Earth mover's distance.
    """

    if isinstance(pdds, np.ndarray) and pdds.ndim == 2:
        pdds = [pdds]
    if isinstance(pdds_, np.ndarray) and pdds_.ndim == 2:
        pdds_ = [pdds_]

    n, m = len(pdds), len(pdds_)
    dm = np.empty((n, m))

    # Always bind eta so the update below never touches an undefined name.
    eta = ETA(n * m, update_rate=(n * m) // 10000) if verbose else None

    for i, pdd in enumerate(pdds):
        for j, pdd_ in enumerate(pdds_):
            dm[i, j] = EMD(pdd, pdd_, metric=metric, **kwargs)
            if eta is not None:
                eta.update()

    return dm
283
284
285
def PDD_pdist(
        pdds: List[np.ndarray],
        metric: str = 'chebyshev',
        verbose=False,
        **kwargs
) -> np.ndarray:
    """Compare a set of PDDs pairwise, returning a condensed distance matrix.

    Parameters
    ----------
    pdds : list of ndarrays
        A list of PDDs. A single 2-D PDD array is treated as a list of one.
    metric : str or callable, optional
        Usually PDD rows are compared with the Chebyshev/l-infinity distance.
        Can take any metric + kwargs accepted by ``scipy.spatial.distance.cdist``.
    verbose : bool, optional
        If True, report progress through an ``ETA`` tracker.

    Returns
    -------
    ndarray
        Returns a condensed distance matrix. Collapses a square
        distance matrix into a vector just keeping the upper half. Use
        ``scipy.spatial.distance.squareform`` to convert to a square
        distance matrix.
    """

    from itertools import combinations

    if isinstance(pdds, np.ndarray) and pdds.ndim == 2:
        pdds = [pdds]

    m = len(pdds)
    cdm_len = (m * (m - 1)) // 2
    cdm = np.empty(cdm_len, dtype=np.double)

    # Always bind eta so the update below never touches an undefined name.
    eta = ETA(cdm_len, update_rate=cdm_len // 10000) if verbose else None

    # combinations yields (i, j) with i < j, i ascending then j ascending:
    # the upper-triangle order of a condensed distance matrix.
    for r, (i, j) in enumerate(combinations(range(m), 2)):
        cdm[r] = EMD(pdds[i], pdds[j], metric=metric, **kwargs)
        if eta is not None:
            eta.update()

    return cdm
327
328
329
def emd(
        pdd: np.ndarray,
        pdd_: np.ndarray,
        metric: Optional[str] = 'chebyshev',
        return_transport: Optional[bool] = False,
        **kwargs):
    """Alias for :func:`EMD`; all arguments are forwarded unchanged."""
    return EMD(pdd, pdd_, metric=metric, return_transport=return_transport, **kwargs)
337