Passed
Push — master ( e05216...ca5a49 )
by Daniel
01:45
created

amd.compare.SDD_EMD()   C

Complexity

Conditions 10

Size

Total Lines 74
Code Lines 31

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 31
dl 0
loc 74
rs 5.9999
c 0
b 0
f 0
cc 10
nop 3

How to fix   Long Method    Complexity   

Long Method

Small methods make your code easier to understand, in particular if combined with a good name. Besides, if your method is small, finding a good name is usually much easier.

For example, if you find yourself adding comments to a method's body, this is usually a good sign to extract the commented part to a new method, and use the comment as a starting point when coming up with a good name for this new method.

Commonly applied refactorings include: Extract Method.

Complexity

Complex code units like amd.compare.SDD_EMD() (here a function rather than a class) often do a lot of different things. To break such a unit down, we need to identify a cohesive component within it. A common approach to find such a component is to look for fields/methods — or, inside a function, variables and steps — that share the same prefixes or suffixes.

Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a sub-class, Extract Subclass is also a candidate, and is often faster.

1
"""Functions for comparing AMDs and PDDs of crystals.
2
"""
3
4
from typing import List, Tuple, Optional, Union
5
import warnings
6
7
import numpy as np
8
import scipy.spatial    # cdist, pdist, squareform
9
import scipy.optimize   # linear_sum_assignment
10
11
from ._network_simplex import network_simplex
12
from .utils import ETA
13
14
15
def EMD(
        pdd: np.ndarray,
        pdd_: np.ndarray,
        metric: Optional[str] = 'chebyshev',
        return_transport: Optional[bool] = False,
        **kwargs):
    r"""Earth mover's distance (EMD) between two PDDs, also known as
    the Wasserstein metric.

    Column 0 of each PDD is used as the row weights; the remaining columns
    are the points whose pairwise ground distances define the transport cost.

    Parameters
    ----------
    pdd : ndarray
        PDD of a crystal.
    pdd\_ : ndarray
        PDD of a crystal.
    metric : str or callable, optional
        Distance between rows of the two PDDs used as the ground distance.
        Defaults to Chebyshev/l-infinity, matching AMD comparisons. Accepts
        any metric (plus ``kwargs``) understood by
        ``scipy.spatial.distance.cdist``.
    return_transport: bool, optional
        If true, return ``(distance, transport_plan)`` where the plan is
        reshaped to ``(len(pdd), len(pdd_))``.

    Returns
    -------
    float
        Earth mover's distance between PDDs.

    Raises
    ------
    ValueError
        Thrown (by ``cdist``) if the two PDDs do not have the
        same number of columns (``k`` value).
    """

    weights, weights_ = pdd[:, 0], pdd_[:, 0]
    points, points_ = pdd[:, 1:], pdd_[:, 1:]

    ground_dm = scipy.spatial.distance.cdist(points, points_, metric=metric, **kwargs)
    distance, flow = network_simplex(weights, weights_, ground_dm)

    # network_simplex's plan is reshaped to the (rows of pdd, rows of pdd_) grid.
    return (distance, flow.reshape(ground_dm.shape)) if return_transport else distance
57
58
59
def SDD_EMD(sdd, sdd_, return_transport: Optional[bool] = False):
    r"""Earth mover's distance (EMD) between two SDDs.

    ``sdd[0]`` supplies the weights for the network simplex and ``sdd[2]``
    the distance arrays. If both distance arrays are 2-dimensional (first
    order SDD) the ground distance is the Chebyshev distance between rows,
    which makes this equivalent to :func:`EMD` on PDDs. Otherwise the ground
    distance also uses ``sdd[1]`` (see :func:`_sdd_higher_order_dm`).

    Parameters
    ----------
    sdd : tuple of ndarrays
        SDD of a crystal.
    sdd\_ : tuple of ndarrays
        SDD of a crystal.
    return_transport: bool, optional
        Return a tuple (distance, transport_plan) with the optimal transport.

    Returns
    -------
    float
        Earth mover's distance between SDDs.

    Raises
    ------
    ValueError
        Thrown if the two SDDs are not of the same order (their distance
        arrays differ in dimension or in trailing size) or, for first order
        SDDs, do not have the same number of columns (``k`` value).
    """

    dists, dists_ = sdd[2], sdd_[2]

    # Fail early with a clear message; otherwise a mismatch surfaces later as
    # an opaque cdist/broadcasting error (or a TypeError).
    if dists.ndim != dists_.ndim or dists.shape[-1] != dists_.shape[-1]:
        raise ValueError(
            'SDDs are not comparable: distance arrays have shapes '
            f'{dists.shape} and {dists_.shape}.')

    if dists.ndim == 2:
        # first order SDD, equivalent to PDD
        dm = scipy.spatial.distance.cdist(dists, dists_, metric='chebyshev')
    else:
        dm = _sdd_higher_order_dm(sdd, sdd_)

    emd_dist, transport_plan = network_simplex(sdd[0], sdd_[0], dm)

    if return_transport:
        return emd_dist, transport_plan.reshape(dm.shape)

    return emd_dist


def _sdd_simplex_dm(dist, dist_, order, n, m):
    """Distance matrix (n, m) between the ``sdd[1]`` entries of two SDDs of order > 1."""
    if order == 2:
        # presumably one scalar per item (shape (n,)), so this broadcasts to
        # an (n, m) matrix of absolute differences -- TODO confirm
        return np.abs(dist[:, None] - dist_)

    # Each entry is compared as a finite, uniformly weighted point set via EMD;
    # presumably dist[i] has ``order`` rows to match the weights -- TODO confirm.
    weights = np.full((order, ), 1 / order)
    simplex_dm = np.empty((n, m), dtype=np.float64)
    for i in range(n):
        for j in range(m):
            finite_pdd_dm = scipy.spatial.distance.cdist(dist[i], dist_[j], metric='chebyshev')
            simplex_dm[i, j], _ = network_simplex(weights, weights, finite_pdd_dm)
    return simplex_dm


def _sdd_higher_order_dm(sdd, sdd_):
    """Ground-distance matrix for the EMD between two SDDs of order > 1."""
    dists, dists_ = sdd[2], sdd_[2]
    order = dists.shape[-1]
    n, m = len(sdd[0]), len(sdd_[0])

    simplex_dm = _sdd_simplex_dm(sdd[1], sdd_[1], order, n, m)

    dm = np.empty((n, m), dtype=np.float64)
    for i in range(n):
        for j in range(m):
            cost_matrix = scipy.spatial.distance.cdist(dists[i], dists_[j], metric='chebyshev')
            # NOTE(review): linear_sum_assignment minimises the *sum* of the
            # matched costs, so the max below is the largest cost of a
            # min-sum matching, not necessarily a minimal bottleneck.
            row_ind, col_ind = scipy.optimize.linear_sum_assignment(cost_matrix)
            dm[i, j] = max(np.amax(cost_matrix[row_ind, col_ind]), simplex_dm[i, j])
    return dm
133
134
135
def AMD_cdist(
        amds: Union[np.ndarray, List[np.ndarray]],
        amds_: Union[np.ndarray, List[np.ndarray]],
        k: Optional[int] = None,
        metric: str = 'chebyshev',
        low_memory: bool = False,
        **kwargs
) -> np.ndarray:
    r"""Compare two sets of AMDs with each other, returning a distance matrix.

    Parameters
    ----------
    amds : array_like
        A list of AMDs (a single 1D AMD is treated as a set of one).
    amds\_ : array_like
        A list of AMDs (a single 1D AMD is treated as a set of one).
    k : int, optional
        Truncate the AMDs to this value of ``k`` before comparing. Applies
        to both the default and the ``low_memory`` method.
    metric : str or callable, optional
        Usually AMDs are compared with the Chebyshev/l-infinity distance.
        Can take any metric + kwargs accepted by ``scipy.spatial.distance.cdist``.
        Ignored (with a warning) when ``low_memory`` is set.
    low_memory : bool, optional
        Use a slower but more memory efficient method for
        large collections of AMDs (Chebyshev/l-inf distance only).

    Returns
    -------
    ndarray
        A distance matrix shape ``(len(amds), len(amds_))``.
        The :math:`ij` th entry is the distance between ``amds[i]``
        and ``amds_[j]`` given by the ``metric``.
    """

    amds = np.asarray(amds)
    amds_ = np.asarray(amds_)

    # Promote a single AMD vector to a one-row collection.
    if amds.ndim == 1:
        amds = np.array([amds])
    if amds_.ndim == 1:
        amds_ = np.array([amds_])

    # Truncate up front so both methods compare exactly the same columns.
    amds = amds[:, :k]
    amds_ = amds_[:, :k]

    if low_memory:
        if metric != 'chebyshev':
            warnings.warn("Using only allowed metric 'chebyshev' for low_memory", UserWarning)

        # One row at a time: only a (len(amds_), k) temporary is alive.
        dm = np.empty((len(amds), len(amds_)))
        for i, amd_vec in enumerate(amds):
            dm[i] = np.amax(np.abs(amds_ - amd_vec), axis=-1)
        return dm

    return scipy.spatial.distance.cdist(amds, amds_, metric=metric, **kwargs)
187
188
189
def AMD_pdist(
        amds: Union[np.ndarray, List[np.ndarray]],
        k: Optional[int] = None,
        low_memory: bool = False,
        metric: str = 'chebyshev',
        **kwargs
) -> np.ndarray:
    """Compare a set of AMDs pairwise, returning a condensed distance matrix.

    Parameters
    ----------
    amds : array_like
        An array/list of AMDs (a single 1D AMD is treated as a set of one).
    k : int, optional
        If :const:`None`, compare whole AMDs (largest ``k``). Set ``k``
        to an int to compare for a specific ``k`` (less than the maximum).
        Applies to both the default and the ``low_memory`` method.
    low_memory : bool, optional
        Optionally use a slightly slower but more memory efficient method for
        large collections of AMDs (Chebyshev/l-inf distance only).
    metric : str or callable, optional
        Usually AMDs are compared with the Chebyshev/l-infinity distance.
        Can take any metric + kwargs accepted by ``scipy.spatial.distance.pdist``.
        Ignored (with a warning) when ``low_memory`` is set.

    Returns
    -------
    ndarray
        Returns a condensed distance matrix. Collapses a square distance
        matrix into a vector just keeping the upper half. Use
        ``scipy.spatial.distance.squareform`` to convert to a square distance matrix.
    """

    amds = np.asarray(amds)

    # Promote a single AMD vector to a one-row collection.
    if amds.ndim == 1:
        amds = np.array([amds])

    # Truncate up front so both methods compare exactly the same columns.
    amds = amds[:, :k]

    if low_memory:
        m = len(amds)
        if metric != 'chebyshev':
            warnings.warn("Using only allowed metric 'chebyshev' for low_memory", UserWarning)
        cdm = np.empty((m * (m - 1)) // 2, dtype=np.double)
        # Row i of the upper triangle (pairs (i, j), j > i) occupies the next
        # m - i - 1 slots of the condensed vector.
        ind = 0
        for i in range(m):
            ind_ = ind + m - i - 1
            cdm[ind:ind_] = np.amax(np.abs(amds[i+1:] - amds[i]), axis=-1)
            ind = ind_
        return cdm

    return scipy.spatial.distance.pdist(amds, metric=metric, **kwargs)
239
240
241
def PDD_cdist(
        pdds: List[np.ndarray],
        pdds_: List[np.ndarray],
        k: Optional[int] = None,
        metric: str = 'chebyshev',
        verbose: bool = False,
        **kwargs
) -> np.ndarray:
    r"""Compare two sets of PDDs with each other, returning a distance matrix.

    Parameters
    ----------
    pdds : list of ndarrays
        A list of PDDs. A single 2D ndarray is treated as one PDD.
    pdds\_ : list of ndarrays
        A list of PDDs. A single 2D ndarray is treated as one PDD.
    k : int, optional
        If :const:`None`, compare whole PDDs (largest ``k``). Set ``k`` to
        an int to compare for a specific ``k`` (less than the maximum).
    metric : str or callable, optional
        Usually PDD rows are compared with the Chebyshev/l-infinity distance.
        Can take any metric + kwargs accepted by
        ``scipy.spatial.distance.cdist``.
    verbose : bool, optional
        Optionally print an ETA to terminal as large collections can take
        some time.

    Returns
    -------
    ndarray
        Returns a distance matrix shape ``(len(pdds), len(pdds_))``.
        The :math:`ij` th entry is the distance between ``pdds[i]``
        and ``pdds_[j]`` given by Earth mover's distance.
    """

    if isinstance(pdds, np.ndarray) and pdds.ndim == 2:
        pdds = [pdds]
    if isinstance(pdds_, np.ndarray) and pdds_.ndim == 2:
        pdds_ = [pdds_]

    # Column 0 of a PDD holds weights, so k distance columns means k + 1 columns.
    n_cols = None if k is None else k + 1
    dm = np.empty((len(pdds), len(pdds_)))
    progress = ETA(len(pdds) * len(pdds_)) if verbose else None

    for i, pdd in enumerate(pdds):
        for j, pdd_ in enumerate(pdds_):
            dm[i, j] = EMD(pdd[:, :n_cols], pdd_[:, :n_cols], metric=metric, **kwargs)
            if progress is not None:
                progress.update()

    return dm
298
299
300
def PDD_pdist(
        pdds: List[np.ndarray],
        k: Optional[int] = None,
        metric: str = 'chebyshev',
        verbose: bool = False,
        **kwargs
) -> np.ndarray:
    """Compare a set of PDDs pairwise, returning a condensed distance matrix.

    Parameters
    ----------
    pdds : list of ndarrays
        A list of PDDs. A single 2D ndarray is treated as one PDD.
    k : int, optional
        If :const:`None`, compare whole PDDs (largest ``k``). Set ``k`` to an int
        to compare for a specific ``k`` (less than the maximum).
    metric : str or callable, optional
        Usually PDD rows are compared with the Chebyshev/l-infinity distance.
        Can take any metric + kwargs accepted by ``scipy.spatial.distance.cdist``.
    verbose : bool, optional
        Optionally print an ETA to terminal as large collections can take
        some time.

    Returns
    -------
    ndarray
        Returns a condensed distance matrix. Collapses a square
        distance matrix into a vector just keeping the upper half. Use
        ``scipy.spatial.distance.squareform`` to convert to a square
        distance matrix.
    """

    if isinstance(pdds, np.ndarray) and pdds.ndim == 2:
        pdds = [pdds]

    count = len(pdds)
    # Column 0 of a PDD holds weights, so k distance columns means k + 1 columns.
    n_cols = None if k is None else k + 1
    n_pairs = (count * (count - 1)) // 2
    cdm = np.empty(n_pairs, dtype=np.double)
    progress = ETA(n_pairs) if verbose else None

    # Visit pairs (i, j) with i < j in row-major order, i.e. condensed layout.
    pos = 0
    for i in range(count - 1):
        for j in range(i + 1, count):
            cdm[pos] = EMD(pdds[i][:, :n_cols], pdds[j][:, :n_cols], metric=metric, **kwargs)
            pos += 1
            if progress is not None:
                progress.update()

    return cdm
349
350
351
def emd(
        pdd: np.ndarray,
        pdd_: np.ndarray,
        metric: Optional[str] = 'chebyshev',
        return_transport: Optional[bool] = False,
        **kwargs):
    """Lowercase alias of :func:`EMD`; every argument is forwarded unchanged."""
    forwarded = dict(kwargs, metric=metric, return_transport=return_transport)
    return EMD(pdd, pdd_, **forwarded)
359
360
361
def PDD_cdist_AMD_filter(
        n: int,
        pdds: List[np.ndarray],
        pdds_: Optional[List[np.ndarray]] = None,
        k: Optional[int] = None,
        low_memory: bool = False,
        metric: str = 'chebyshev',
        verbose: bool = False,
        **kwargs
) -> Tuple[np.ndarray, np.ndarray]:
    r"""For each item in ``pdds``, get the ``n`` nearest items in ``pdds_`` by AMD,
    then compare references to these nearest items with PDDs.
    Tries to compromise between the speed of AMDs and the accuracy of PDDs.

    If ``pdds_`` is :const:`None`, this essentially sets ``pdds_ = pdds``, i.e.
    do an 'AMD neighbourhood graph' for one set whose weights are PDD distances.
    In that case an item is never listed as its own neighbour.

    Parameters
    ----------
    n : int
        Number of nearest neighbours to find.
    pdds : list of ndarrays
        A list of PDDs.
    pdds\_ : list of ndarrays, optional
        A list of PDDs.
    k : int, optional
        If :const:`None`, compare entire PDDs. Set ``k`` to an int
        to compare for a specific ``k`` (less than the maximum).
    low_memory : bool, optional
        Optionally use a slightly slower but more memory efficient method for
        large collections of AMDs (Chebyshev/l-inf distance only).
    metric : str or callable, optional
        Usually PDD rows are compared with the Chebyshev/l-infinity distance.
        Can take any metric + kwargs accepted by ``scipy.spatial.distance.cdist``.
    verbose : bool, optional
        Optionally print an ETA to terminal as large collections can take some time.

    Returns
    -------
    tuple of ndarrays (distance_matrix, indices)
        For the :math:`i` th item in reference and some :math:`j<n`,
        ``distance_matrix[i][j]`` is the distance from reference i to its j-th
        nearest neighbour in comparison (after the AMD filter).
        ``indices[i][j]`` is the index of said neighbour in ``pdds_``.
    """

    metric_kwargs = {'metric': metric, **kwargs}
    shared_kwargs = {'k': k, 'metric': metric, **kwargs}

    comparison_set_size = len(pdds) if pdds_ is None else len(pdds_)

    # No filtering possible: compare everything with PDDs and sort each row.
    # NOTE(review): in the pairwise case the squareform matrix has a zero
    # diagonal, so each row here includes the item itself (unlike the
    # filtered path below) -- confirm this is intended.
    if n >= comparison_set_size:

        if pdds_ is None:
            pdd_cdm = PDD_pdist(pdds, verbose=verbose, **shared_kwargs)
            dm = scipy.spatial.distance.squareform(pdd_cdm)
        else:
            dm = PDD_cdist(pdds, pdds_, verbose=verbose, **shared_kwargs)

        inds = np.argsort(dm, axis=-1)
        dm = np.take_along_axis(dm, inds, axis=-1)
        return dm, inds

    amds = _amds_from_pdds(pdds, k)

    if pdds_ is None:
        # one set, pairwise
        pdds_ = pdds
        inds = _amd_neighbours_pairwise(amds, n, low_memory, metric, verbose, shared_kwargs)
    else:
        # one set v another
        amds_ = _amds_from_pdds(pdds_, k)
        inds = _amd_neighbours_cross(amds, amds_, n, low_memory, metric, verbose, shared_kwargs)

    # Refine the AMD-selected candidates with the (slower, exact) PDD EMD.
    dm = np.empty(inds.shape)
    progress = ETA(inds.size) if verbose else None
    # Column 0 of a PDD holds weights, so k distance columns means k + 1 columns.
    t = None if k is None else k + 1

    for i, row in enumerate(inds):
        for col, j in enumerate(row):
            dm[i, col] = EMD(pdds[i][:, :t], pdds_[j][:, :t], **metric_kwargs)
            if progress is not None:
                progress.update()

    sorted_inds = np.argsort(dm, axis=-1)
    inds = np.take_along_axis(inds, sorted_inds, axis=-1)
    dm = np.take_along_axis(dm, sorted_inds, axis=-1)

    return dm, inds


def _amds_from_pdds(pdds, k):
    """AMD of each PDD: its distance columns averaged with column 0 as weights, truncated to ``k``."""
    return np.array([np.average(pdd[:, 1:], weights=pdd[:, 0], axis=0)[:k]
                     for pdd in pdds])


def _warn_if_not_chebyshev(metric):
    """Warn that the ``low_memory`` AMD comparison always uses the Chebyshev metric."""
    if metric != 'chebyshev':
        warnings.warn(
            "Using only allowed metric 'chebyshev' for low_memory",
            UserWarning)


def _n_nearest_excluding(row, i, n):
    """Indices of ``n`` smallest entries of ``row`` other than index ``i`` (unsorted; needs ``n < len(row)``)."""
    # kth=n puts the (n+1)-th smallest value at position n, so the first n + 1
    # positions hold the n + 1 smallest values. Unlike kth=n+1 this stays in
    # bounds when n + 1 == len(row).
    candidates = np.argpartition(row, n)[:n + 1]
    return candidates[candidates != i][:n]


def _amd_neighbours_pairwise(amds, n, low_memory, metric, verbose, amd_kwargs):
    """For each AMD, indices of its ``n`` nearest *other* AMDs in the same set."""
    if low_memory:
        _warn_if_not_chebyshev(metric)
        progress = ETA(len(amds)) if verbose else None
        inds = []
        for i, amd_vec in enumerate(amds):
            row = np.amax(np.abs(amds - amd_vec), axis=-1)
            inds.append(_n_nearest_excluding(row, i, n))
            if progress is not None:
                progress.update()
        return np.array(inds)

    amd_dm = scipy.spatial.distance.squareform(AMD_pdist(amds, **amd_kwargs))
    return np.array([_n_nearest_excluding(row, i, n) for i, row in enumerate(amd_dm)])


def _amd_neighbours_cross(amds, amds_, n, low_memory, metric, verbose, amd_kwargs):
    """For each AMD in ``amds``, indices of its ``n`` nearest AMDs in ``amds_``."""
    if low_memory:
        _warn_if_not_chebyshev(metric)
        # One update per AMD in ``amds``, so size the ETA accordingly.
        progress = ETA(len(amds)) if verbose else None
        inds = []
        for amd_vec in amds:
            row = np.amax(np.abs(amds_ - amd_vec), axis=-1)
            inds.append(np.argpartition(row, n)[:n])
            if progress is not None:
                progress.update()
        # Must be an ndarray: the caller relies on ``.shape``/``.size``.
        return np.array(inds)

    amd_dm = AMD_cdist(amds, amds_, **amd_kwargs)
    return np.array([np.argpartition(row, n)[:n] for row in amd_dm])
493