Passed
Push — master ( 2d0e3f...af43de )
by Daniel
01:39
created

amd.compare.SDD_EMD()   C

Complexity

Conditions 10

Size

Total Lines 74
Code Lines 31

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric  Value
eloc    31
dl      0
loc     74
rs      5.9999
c       0
b       0
f       0
cc      10
nop     3

How to fix: Long Method, Complexity

Long Method

Small methods make your code easier to understand, in particular if combined with a good name. Besides, if your method is small, finding a good name is usually much easier.

For example, if you find yourself adding comments to a method's body, this is usually a good sign to extract the commented part to a new method, and use the comment as a starting point when coming up with a good name for this new method.

Commonly applied refactorings include Extract Method.

Complexity

Complex classes like amd.compare.SDD_EMD() often do a lot of different things. To break such a class down, we need to identify a cohesive component within that class. A common approach to find such a component is to look for fields/methods that share the same prefixes, or suffixes.

Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a sub-class, Extract Subclass is also a candidate, and is often faster.
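
As a hedged illustration of the Extract Method advice: the two nested loops in SDD_EMD (in the listing that follows) both fill an n-by-m matrix by applying some cost to every pair (i, j), so that shared shape could be pulled into one small helper. The name `pairwise_matrix`, its signature, and the toy cost `abs_diff` are hypothetical and not part of `amd.compare`; plain lists stand in for NumPy arrays to keep the sketch dependency-free.

```python
from typing import Callable, List, Sequence, TypeVar

A = TypeVar("A")
B = TypeVar("B")


def pairwise_matrix(
    items: Sequence[A],
    items_: Sequence[B],
    cost: Callable[[A, B], float],
) -> List[List[float]]:
    """Return the matrix whose (i, j) entry is cost(items[i], items_[j])."""
    return [[cost(a, b) for b in items_] for a in items]


def abs_diff(a: float, b: float) -> float:
    """Toy cost standing in for the per-pair EMD or assignment cost."""
    return abs(a - b)


matrix = pairwise_matrix([1.0, 2.0], [3.0], abs_diff)
```

With a helper of this kind, each loop in the long function would shrink to a single call plus a named cost function, which is where most of the counted conditions come from.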

1
"""Functions for comparing AMDs and PDDs of crystals.
2
"""
3
4
from turtle import update
Bug: The name update does not seem to exist in module turtle.
Unused Code: Unused update imported from turtle
5
from typing import List, Tuple, Optional, Union
6
import warnings
7
8
import numpy as np
9
import scipy.spatial    # cdist, pdist, squareform
10
import scipy.optimize   # linear_sum_assignment
11
12
from ._network_simplex import network_simplex
13
from .utils import ETA
14
15
16
def set_verbose(setting, update_rate=100):
Missing function or method docstring
17
    global _VERBOSE
Bug: Global variable '_VERBOSE' undefined at the module level
18
    global _VERBOSE_UPDATE_RATE
Bug: Global variable '_VERBOSE_UPDATE_RATE' undefined at the module level
19
    _VERBOSE = setting
20
    _VERBOSE_UPDATE_RATE = update_rate
21
    
Coding Style: Trailing whitespace
22
set_verbose(False)
23
24
def EMD(
25
        pdd: np.ndarray,
26
        pdd_: np.ndarray,
27
        metric: Optional[str] = 'chebyshev',
28
        return_transport: Optional[bool] = False,
29
        **kwargs):
30
    r"""Earth mover's distance (EMD) between two PDDs, also known as
31
    the Wasserstein metric.
32
33
    Parameters
34
    ----------
35
    pdd : ndarray
36
        PDD of a crystal.
37
    pdd\_ : ndarray
38
        PDD of a crystal.
39
    metric : str or callable, optional
40
        EMD between PDDs requires defining a distance between rows of two PDDs.
41
        By default, Chebyshev/l-infinity distance is chosen as with AMDs.
42
        Can take any metric + ``kwargs`` accepted by
43
        ``scipy.spatial.distance.cdist``.
44
    return_transport : bool, optional
45
        Return a tuple (distance, transport_plan) with the optimal transport.
46
47
    Returns
48
    -------
49
    float
50
        Earth mover's distance between PDDs.
51
52
    Raises
53
    ------
54
    ValueError
55
        Thrown if the two PDDs do not have the
56
        same number of columns (``k`` value).
57
    """
58
59
    dm = scipy.spatial.distance.cdist(pdd[:, 1:], pdd_[:, 1:], metric=metric, **kwargs)
60
    emd_dist, transport_plan = network_simplex(pdd[:, 0], pdd_[:, 0], dm)
61
62
    if return_transport:
63
        return emd_dist, transport_plan.reshape(dm.shape)
64
65
    return emd_dist
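
To make the two steps of EMD concrete, here is a minimal dependency-free sketch of the first one: building the Chebyshev cost matrix between PDD rows while skipping the leading weight column. The helper names and the toy PDDs are hypothetical. The transport step itself (`network_simplex`) is not reproduced; the sketch only covers the special case where the second PDD has a single row, in which all mass must flow to that row.

```python
from typing import List, Sequence


def chebyshev(u: Sequence[float], v: Sequence[float]) -> float:
    """l-infinity distance: the largest coordinate-wise absolute difference."""
    return max(abs(a - b) for a, b in zip(u, v))


def cost_matrix(pdd: List[List[float]], pdd_: List[List[float]]) -> List[List[float]]:
    """Chebyshev distances between rows, ignoring the leading weight column."""
    return [[chebyshev(row[1:], row_[1:]) for row_ in pdd_] for row in pdd]


# Toy PDDs: first column is the weight, the rest are distances (k = 2).
pdd = [[0.5, 1.0, 2.0], [0.5, 1.5, 2.5]]
pdd_ = [[1.0, 1.0, 2.0]]

dm = cost_matrix(pdd, pdd_)

# pdd_ has a single row of weight 1, so every unit of mass must move to it
# and the EMD reduces to the weighted sum of the costs.
emd_single_target = sum(row[0] * costs[0] for row, costs in zip(pdd, dm))
```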
66
67
68
def SDD_EMD(sdd, sdd_, return_transport: Optional[bool] = False):
69
    r"""Earth mover's distance (EMD) between two SDDs.
70
71
    Parameters
72
    ----------
73
    sdd : tuple of ndarrays
74
        SDD of a crystal.
75
    sdd\_ : tuple of ndarrays
76
        SDD of a crystal.
77
    return_transport: bool, optional
78
        Return a tuple (distance, transport_plan) with the optimal transport.
79
80
    Returns
81
    -------
82
    float
83
        Earth mover's distance between SDDs.
84
85
    Raises
86
    ------
87
    ValueError
88
        Thrown if the two SDDs are not of the same order or do not have the
89
        same number of columns (``k`` value).
90
    """
91
92
    dists, dists_ = sdd[2], sdd_[2]
93
94
    # first order SDD, equivalent to PDD
95
    if dists.ndim == 2 and dists_.ndim == 2:
96
        dm = scipy.spatial.distance.cdist(dists, dists_, metric='chebyshev')
97
        emd_dist, transport_plan = network_simplex(sdd[0], sdd_[0], dm)
98
99
        if return_transport:
100
            return emd_dist, transport_plan.reshape(dm.shape)
101
102
        return emd_dist
103
104
    order = dists.shape[-1]
105
    n, m = len(sdd[0]), len(sdd_[0])
106
107
    dist_cdist = None
108
    if order == 2:
109
        dist_cdist = np.abs(sdd[1][:, None] - sdd_[1])
110
    else:
111
        dist, dist_ = sdd[1], sdd_[1]
112
113
        # take EMDs between finite PDDs in dist column
114
        weights = np.full((order, ), 1 / order)
115
        dist_cdist = np.empty((n, m), dtype=np.float64)
116
        for i in range(n):
117
            for j in range(m):
118
                finite_pdd_dm = scipy.spatial.distance.cdist(dist[i], dist_[j], metric='chebyshev')
119
                dists_emd, _ = network_simplex(weights, weights, finite_pdd_dm)
120
                dist_cdist[i, j] = dists_emd
121
122
        # flatten and compare by linf
123
        # flat_dist = dist.reshape((n, order * (order - 1)))
124
        # flat_dist_ = dist_.reshape((m, order * (order - 1)))
125
        # flat_dist = np.sort(flat_dist, axis=-1)
126
        # flat_dist_ = np.sort(flat_dist_, axis=-1)
127
        # dist_cdist = scipy.spatial.distance.cdist(flat_dist, flat_dist_, metric='chebyshev')
128
129
    dm = np.empty((n, m), dtype=np.float64)
130
    for i in range(n):
131
        for j in range(m):
132
            cost_matrix = scipy.spatial.distance.cdist(dists[i], dists_[j], metric='chebyshev')
133
            row_ind, col_ind = scipy.optimize.linear_sum_assignment(cost_matrix)
134
            dm[i, j] = max(np.amax(cost_matrix[row_ind, col_ind]), dist_cdist[i, j])
135
136
    emd_dist, transport_plan = network_simplex(sdd[0], sdd_[0], dm)
137
138
    if return_transport:
139
        return emd_dist, transport_plan.reshape(dm.shape)
140
141
    return emd_dist
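
The inner loop of SDD_EMD takes a minimum-sum assignment between the rows of two distance blocks and then keeps the largest single matched cost. The sketch below reproduces that logic for tiny square matrices with a brute-force search over permutations; it is a stand-in for `scipy.optimize.linear_sum_assignment`, not a replacement, and the function names are hypothetical.

```python
from itertools import permutations
from typing import Sequence, Tuple


def min_sum_assignment(cost: Sequence[Sequence[float]]) -> Tuple[int, ...]:
    """Brute-force the column permutation minimising the total matched cost."""
    size = len(cost)
    return min(
        permutations(range(size)),
        key=lambda cols: sum(cost[r][c] for r, c in enumerate(cols)),
    )


def largest_matched_cost(cost: Sequence[Sequence[float]]) -> float:
    """Largest single cost among the pairs chosen by the min-sum assignment."""
    cols = min_sum_assignment(cost)
    return max(cost[r][c] for r, c in enumerate(cols))


cost_matrix = [[1.0, 4.0], [2.0, 8.0]]
worst = largest_matched_cost(cost_matrix)
```

Note that minimising the sum and then taking the maximum is not necessarily the same as minimising the maximum directly; the sketch mirrors what the listing does rather than arguing for either choice.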
142
143
144
def AMD_cdist(
145
        amds: Union[np.ndarray, List[np.ndarray]],
146
        amds_: Union[np.ndarray, List[np.ndarray]],
147
        metric: str = 'chebyshev',
148
        low_memory: bool = False,
149
        **kwargs
150
) -> np.ndarray:
151
    r"""Compare two sets of AMDs with each other, returning a distance matrix.
152
153
    Parameters
154
    ----------
155
    amds : array_like
156
        A list of AMDs.
157
    amds\_ : array_like
158
        A list of AMDs.
159
    metric : str or callable, optional
160
        Usually AMDs are compared with the Chebyshev/l-infinity distance.
161
        Can take any metric + kwargs accepted by ``scipy.spatial.distance.cdist``.
162
    low_memory : bool, optional
163
        Use a slower but more memory efficient method for
164
        large collections of AMDs (Chebyshev/l-inf distance only).
165
166
    Returns
167
    -------
168
    ndarray
169
        A distance matrix shape ``(len(amds), len(amds_))``.
170
        The :math:`ij` th entry is the distance between ``amds[i]``
171
        and ``amds_[j]`` given by the ``metric``.
172
    """
173
174
    amds, amds_ = np.asarray(amds), np.asarray(amds_)
175
176
    if len(amds.shape) == 1:
177
        amds = np.array([amds])
178
    if len(amds_.shape) == 1:
179
        amds_ = np.array([amds_])
180
181
    if low_memory:
182
        if metric != 'chebyshev':
183
            warnings.warn("Using only allowed metric 'chebyshev' for low_memory", UserWarning)
184
185
        dm = np.empty((len(amds), len(amds_)))
186
        for i, amd_vec in enumerate(amds):
187
            dm[i] = np.amax(np.abs(amds_ - amd_vec), axis=-1)
188
    else:
189
        dm = scipy.spatial.distance.cdist(amds, amds_, metric=metric, **kwargs)
190
191
    return dm
192
193
194
def AMD_pdist(
195
        amds: Union[np.ndarray, List[np.ndarray]],
196
        metric: str = 'chebyshev',
197
        low_memory: bool = False,
198
        **kwargs
199
) -> np.ndarray:
200
    """Compare a set of AMDs pairwise, returning a condensed distance matrix.
201
202
    Parameters
203
    ----------
204
    amds : array_like
205
        An array/list of AMDs.
206
    metric : str or callable, optional
207
        Usually AMDs are compared with the Chebyshev/l-infinity distance.
208
        Can take any metric + kwargs accepted by ``scipy.spatial.distance.pdist``.
209
    low_memory : bool, optional
210
        Optionally use a slightly slower but more memory efficient method for
211
        large collections of AMDs (Chebyshev/l-inf distance only).
212
213
    Returns
214
    -------
215
    ndarray
216
        Returns a condensed distance matrix. Collapses a square distance
217
        matrix into a vector just keeping the upper half. Use
218
        ``scipy.spatial.distance.squareform`` to convert to a square distance matrix.
219
    """
220
221
    amds = np.asarray(amds)
222
223
    if len(amds.shape) == 1:
224
        amds = np.array([amds])
225
226
    if low_memory:
227
        m = len(amds)
228
        if metric != 'chebyshev':
229
            warnings.warn("Using only allowed metric 'chebyshev' for low_memory", UserWarning)
230
        cdm = np.empty((m * (m - 1)) // 2, dtype=np.double)
231
        ind = 0
232
        for i in range(m):
233
            ind_ = ind + m - i - 1
234
            cdm[ind:ind_] = np.amax(np.abs(amds[i+1:] - amds[i]), axis=-1)
235
            ind = ind_
236
    else:
237
        cdm = scipy.spatial.distance.pdist(amds, metric=metric, **kwargs)
238
239
    return cdm
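
The low-memory branch of AMD_pdist fills the condensed matrix one row of the upper triangle at a time, advancing the offsets `ind` and `ind_`. A dependency-free sketch of the same bookkeeping, together with the closed-form position of a pair (i, j) in the condensed vector; `chebyshev_pdist` and `condensed_index` are hypothetical names, and lists replace NumPy arrays.

```python
from typing import List, Sequence


def chebyshev_pdist(vectors: Sequence[Sequence[float]]) -> List[float]:
    """Condensed Chebyshev distances, filled one upper-triangle row at a time."""
    m = len(vectors)
    cdm = [0.0] * (m * (m - 1) // 2)
    ind = 0
    for i in range(m):
        ind_ = ind + m - i - 1          # row i contributes m - i - 1 entries
        cdm[ind:ind_] = [
            max(abs(a - b) for a, b in zip(vectors[i], other))
            for other in vectors[i + 1:]
        ]
        ind = ind_
    return cdm


def condensed_index(i: int, j: int, m: int) -> int:
    """Position of the pair (i, j), i < j, in a condensed matrix of m items."""
    return m * i + j - ((i + 2) * (i + 1)) // 2


amds = [[0.0, 0.0], [1.0, 3.0], [2.0, 2.0], [5.0, 1.0]]
cdm = chebyshev_pdist(amds)
```

For example, `cdm[condensed_index(1, 3, 4)]` picks out the distance between the second and fourth vectors without building the square matrix.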
240
241
242
def PDD_cdist(
243
        pdds: List[np.ndarray],
244
        pdds_: List[np.ndarray],
245
        metric: str = 'chebyshev',
246
        **kwargs
247
) -> np.ndarray:
248
    r"""Compare two sets of PDDs with each other, returning a distance matrix.
249
250
    Parameters
251
    ----------
252
    pdds : list of ndarrays
253
        A list of PDDs.
254
    pdds\_ : list of ndarrays
255
        A list of PDDs.
256
    metric : str or callable, optional
257
        Usually PDD rows are compared with the Chebyshev/l-infinity distance.
258
        Can take any metric + kwargs accepted by
259
        ``scipy.spatial.distance.cdist``.
260
261
    Returns
262
    -------
263
    ndarray
264
        Returns a distance matrix shape ``(len(pdds), len(pdds_))``.
265
        The :math:`ij` th entry is the distance between ``pdds[i]``
266
        and ``pdds_[j]`` given by Earth mover's distance.
267
    """
268
269
    if isinstance(pdds, np.ndarray):
270
        if len(pdds.shape) == 2:
271
            pdds = [pdds]
272
273
    if isinstance(pdds_, np.ndarray):
274
        if len(pdds_.shape) == 2:
275
            pdds_ = [pdds_]
276
277
    n, m = len(pdds), len(pdds_)
278
    dm = np.empty((n, m))
279
    if _VERBOSE:
Comprehensibility Best Practice: The variable _VERBOSE does not seem to be defined.
280
        eta = ETA(n * m, update_rate=_VERBOSE_UPDATE_RATE)
Comprehensibility Best Practice: The variable _VERBOSE_UPDATE_RATE does not seem to be defined.
281
282
    for i in range(n):
283
        pdd = pdds[i]
284
        for j in range(m):
285
            dm[i, j] = EMD(pdd, pdds_[j], metric=metric, **kwargs)
286
            if _VERBOSE:
287
                eta.update()
The variable eta does not seem to be defined in case _VERBOSE on line 279 is False. Are you sure this can never be the case?
288
289
    return dm
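
The "eta does not seem to be defined" finding recurs in every function that reports progress. A minimal sketch of one possible remedy: bind the name on every path and test it for `None` instead of re-reading the verbosity flag inside the loop. `Progress` is a hypothetical stand-in for the `ETA` helper and `fill_matrix` is an illustrative function, neither taken from the library.

```python
from typing import Optional


class Progress:
    """Hypothetical stand-in for the ETA helper: it only counts updates."""

    def __init__(self, total: int) -> None:
        self.total = total
        self.done = 0

    def update(self) -> None:
        self.done += 1


def fill_matrix(n: int, m: int, verbose: bool) -> Optional[Progress]:
    """Run an n-by-m loop, reporting progress only when verbose is set."""
    # Bind the name on every path so the loop body can test it safely.
    eta: Optional[Progress] = Progress(n * m) if verbose else None
    for _ in range(n):
        for _ in range(m):
            # ... compute one matrix entry here ...
            if eta is not None:
                eta.update()
    return eta


quiet = fill_matrix(2, 3, verbose=False)
loud = fill_matrix(2, 3, verbose=True)
```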
290
291
292
def PDD_pdist(
293
        pdds: List[np.ndarray],
294
        metric: str = 'chebyshev',
295
        **kwargs
296
) -> np.ndarray:
297
    """Compare a set of PDDs pairwise, returning a condensed distance matrix.
298
299
    Parameters
300
    ----------
301
    pdds : list of ndarrays
302
        A list of PDDs.
303
    metric : str or callable, optional
304
        Usually PDD rows are compared with the Chebyshev/l-infinity distance.
305
        Can take any metric + kwargs accepted by ``scipy.spatial.distance.cdist``.
306
307
    Returns
308
    -------
309
    ndarray
310
        Returns a condensed distance matrix. Collapses a square
311
        distance matrix into a vector just keeping the upper half. Use
312
        ``scipy.spatial.distance.squareform`` to convert to a square
313
        distance matrix.
314
    """
315
316
    if isinstance(pdds, np.ndarray):
317
        if len(pdds.shape) == 2:
318
            pdds = [pdds]
319
320
    m = len(pdds)
321
    cdm = np.empty((m * (m - 1)) // 2, dtype=np.double)
322
    if _VERBOSE:
Comprehensibility Best Practice: The variable _VERBOSE does not seem to be defined.
323
        eta = ETA((m * (m - 1)) // 2, update_rate=_VERBOSE_UPDATE_RATE)
Comprehensibility Best Practice: The variable _VERBOSE_UPDATE_RATE does not seem to be defined.
324
    inds = ((i, j) for i in range(0, m - 1) for j in range(i + 1, m))
Comprehensibility Best Practice: The variable j does not seem to be defined.
Comprehensibility Best Practice: The variable i does not seem to be defined.
325
326
    for r, (i, j) in enumerate(inds):
327
        cdm[r] = EMD(pdds[i], pdds[j], metric=metric, **kwargs)
328
        if _VERBOSE:
329
            eta.update()
The variable eta does not seem to be defined in case _VERBOSE on line 322 is False. Are you sure this can never be the case?
330
331
    return cdm
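
The index generator in PDD_pdist yields each unordered pair once, in the order a condensed distance matrix expects. The sketch below checks that `itertools.combinations` produces the same sequence, which may also sidestep the analyzer's confusion about `i` and `j`; the size `m = 4` is an arbitrary toy value.

```python
from itertools import combinations

m = 4
# Generator as written in PDD_pdist.
inds = ((i, j) for i in range(0, m - 1) for j in range(i + 1, m))
# Equivalent standard-library spelling.
pairs = combinations(range(m), 2)

inds_list = list(inds)
pairs_list = list(pairs)
```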
332
333
334
def emd(
335
        pdd: np.ndarray,
336
        pdd_: np.ndarray,
337
        metric: Optional[str] = 'chebyshev',
338
        return_transport: Optional[bool] = False, 
Coding Style: Trailing whitespace
339
        **kwargs):
340
    """Alias for amd.emd()."""
341
    return EMD(pdd, pdd_, metric=metric, return_transport=return_transport, **kwargs)
342
343
344
def PDD_cdist_AMD_filter(
345
        n: int,
346
        pdds: List[np.ndarray],
347
        pdds_: Optional[List[np.ndarray]] = None,
348
        low_memory: bool = False,
349
        metric: str = 'chebyshev',
350
        **kwargs
351
) -> Tuple[np.ndarray, np.ndarray]:
352
    r"""For each item in ``pdds``, get the ``n`` nearest items in ``pdds_`` by AMD,
353
    then compare each item with these nearest items by PDD.
354
    Tries to compromise between the speed of AMDs and the accuracy of PDDs.
355
356
    If ``pdds_`` is :const:`None`, this essentially sets ``pdds_ = pdds``, i.e.
357
    do an 'AMD neighbourhood graph' for one set whose weights are PDD distances.
358
359
    Parameters
360
    ----------
361
    n : int
362
        Number of nearest neighbours to find.
363
    pdds : list of ndarrays
364
        A list of PDDs.
365
    pdds\_ : list of ndarrays, optional
366
        A list of PDDs.
367
    low_memory : bool, optional
368
        Optionally use a slightly slower but more memory efficient method for
369
        large collections of AMDs (Chebyshev/l-inf distance only).
370
    metric : str or callable, optional
371
        Usually PDD rows are compared with the Chebyshev/l-infinity distance.
372
        Can take any metric + kwargs accepted by ``scipy.spatial.distance.cdist``.
373
374
    Returns
375
    -------
376
    tuple of ndarrays (distance_matrix, indices)
377
        For the :math:`i` th item in ``pdds`` and some :math:`j<n`,
378
        ``distance_matrix[i][j]`` is the distance from ``pdds[i]`` to its j-th
379
        nearest neighbour in ``pdds_`` (after the AMD filter).
380
        ``indices[i][j]`` is the index of said neighbour in ``pdds_``.
381
    """
382
383
    kwargs = {'metric': metric, **kwargs}
384
    amds = np.array([np.average(pdd[:, 1:], weights=pdd[:, 0], axis=0)
385
                     for pdd in pdds])
386
387
    if low_memory:
388
        if metric != 'chebyshev':
389
            warnings.warn(
390
                "Using only allowed metric 'chebyshev' for low_memory",
391
                UserWarning)
392
        if pdds_ is None:
393
            inds = _amd_pdist_nns_low_memory(amds, n)
394
            pdds_ = pdds
395
        else:
396
            amds_ = np.array([np.average(pdd[:, 1:], weights=pdd[:, 0], axis=0)
397
                              for pdd in pdds_])
398
            inds = _amd_cdist_nns_low_memory(amds, amds_, n) 
Coding Style: Trailing whitespace
399
    else:
400
        if pdds_ is None:
401
            inds = _amd_pdist_nns(amds, n, **kwargs)
402
        else:
403
            amd_dm = AMD_cdist(amds, amds_, low_memory=low_memory, **kwargs)
Comprehensibility Best Practice: The variable amds_ does not seem to be defined.
404
            inds = np.array([np.argpartition(row, n)[:n] for row in amd_dm])
405
406
    dm, inds = _emd_nns_from_nn_inds(pdds, pdds_, inds, **kwargs)
407
    return dm, inds
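
The `amds_` finding looks like a genuine defect: in the listing, `amds_` is only computed inside the low-memory branch, yet the other branch passes it to AMD_cdist. It also appears that the non-low-memory branch leaves `pdds_` as `None` when it is omitted. One possible restructuring, sketched without NumPy, is to resolve all three names before branching; `amd_from_pdd` and `prepare_inputs` are hypothetical helpers, not library functions.

```python
from typing import List, Optional, Tuple

PDD = List[List[float]]


def amd_from_pdd(pdd: PDD) -> List[float]:
    """Weighted column average of a PDD, using column 0 as the weights."""
    total = sum(row[0] for row in pdd)
    k = len(pdd[0]) - 1
    return [sum(row[0] * row[c + 1] for row in pdd) / total for c in range(k)]


def prepare_inputs(
    pdds: List[PDD], pdds_: Optional[List[PDD]] = None
) -> Tuple[List[List[float]], List[List[float]], List[PDD]]:
    """Return (amds, amds_, pdds_) with every name bound on every path."""
    amds = [amd_from_pdd(p) for p in pdds]
    if pdds_ is None:
        # Comparing one set with itself: reuse both the PDDs and their AMDs.
        return amds, amds, pdds
    return amds, [amd_from_pdd(p) for p in pdds_], pdds_


pdds = [[[0.5, 1.0, 2.0], [0.5, 3.0, 4.0]]]
amds, amds_, pdds_ = prepare_inputs(pdds)
```

The caller would still need to remember whether `pdds_` was originally omitted, since the self-comparison case has to exclude each item from its own neighbour list.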
408
409
410
def _amd_pdist_nns(amds, n, **kwargs):
411
    amd_cdm = AMD_pdist(amds, **kwargs)
412
    amd_dm = scipy.spatial.distance.squareform(amd_cdm)
413
    inds = []
414
    for i, row in enumerate(amd_dm):
415
        inds_row = np.argpartition(row, n+1)[:n+1]
416
        inds_row = inds_row[inds_row != i][:n]
417
        inds.append(inds_row)
418
    inds = np.array(inds)
419
    return inds
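
`_amd_pdist_nns` asks for `n + 1` candidates per row and then drops the item's own index, because every item is at distance zero from itself. The same idea can be sketched with the standard library's `heapq.nsmallest` in place of `np.argpartition`; this is a swapped-in technique for illustration, and unlike `argpartition` it returns the neighbours already ordered by distance. The helper name is hypothetical.

```python
import heapq
from typing import List, Sequence


def nearest_excluding_self(dists: Sequence[float], i: int, n: int) -> List[int]:
    """Indices of the n smallest entries of dists, skipping position i."""
    candidates = (j for j in range(len(dists)) if j != i)
    return heapq.nsmallest(n, candidates, key=lambda j: dists[j])


row = [0.0, 0.3, 0.1, 0.7]   # distances from item 0 to every item
neighbours = nearest_excluding_self(row, i=0, n=2)
```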
420
421
422
def _amd_pdist_nns_low_memory(amds, n):
423
    inds = []
424
    if _VERBOSE:
Comprehensibility Best Practice: The variable _VERBOSE does not seem to be defined.
425
        eta = ETA(len(amds), update_rate=_VERBOSE_UPDATE_RATE)
Comprehensibility Best Practice: The variable _VERBOSE_UPDATE_RATE does not seem to be defined.
426
    for i, amd_vec in enumerate(amds):
427
        dists = np.amax(np.abs(amds - amd_vec), axis=-1)
428
        inds_row = np.argpartition(dists, n+1)[:n+1]
429
        inds_row = inds_row[inds_row != i][:n]
430
        inds.append(inds_row)
431
        if _VERBOSE:
432
            eta.update()
The variable eta does not seem to be defined in case _VERBOSE on line 424 is False. Are you sure this can never be the case?
433
    return np.array(inds)
434
435
436
def _amd_cdist_nns_low_memory(amds, amds_, n):
437
    inds = []
438
    if _VERBOSE:
Comprehensibility Best Practice: The variable _VERBOSE does not seem to be defined.
439
        eta = ETA(len(amds) * len(amds_), update_rate=_VERBOSE_UPDATE_RATE)
Comprehensibility Best Practice: The variable _VERBOSE_UPDATE_RATE does not seem to be defined.
440
    for amd_vec in amds:
441
        row = np.amax(np.abs(amds_ - amd_vec), axis=-1)
442
        inds.append(np.argpartition(row, n)[:n])
443
        if _VERBOSE:
444
            eta.update()
The variable eta does not seem to be defined in case _VERBOSE on line 438 is False. Are you sure this can never be the case?
445
    return np.array(inds)
446
447
448
def _emd_nns_from_nn_inds(pdds, pdds_, inds, **kwargs):
449
    dm = np.empty(inds.shape)
450
    if _VERBOSE:
Comprehensibility Best Practice: The variable _VERBOSE does not seem to be defined.
451
        eta = ETA(inds.shape[0] * inds.shape[1], update_rate=_VERBOSE_UPDATE_RATE)
Comprehensibility Best Practice: The variable _VERBOSE_UPDATE_RATE does not seem to be defined.
452
453
    for i, row in enumerate(inds):
454
        for i_, j in enumerate(row):
455
            dm[i, i_] = EMD(pdds[i], pdds_[j], **kwargs)
456
            if _VERBOSE:
457
                eta.update()
The variable eta does not seem to be defined in case _VERBOSE on line 450 is False. Are you sure this can never be the case?
458
459
    sorted_inds = np.argsort(dm, axis=-1)
460
    inds = np.take_along_axis(inds, sorted_inds, axis=-1)
461
    dm = np.take_along_axis(dm, sorted_inds, axis=-1)
462
    
Coding Style: Trailing whitespace
463
    return dm, inds
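
The last three statements of `_emd_nns_from_nn_inds` sort each row of distances and apply the same permutation to the neighbour indices, so that column j always refers to the j-th nearest neighbour. A dependency-free sketch of that tandem reordering, with a hypothetical helper name and nested lists in place of `np.argsort` and `np.take_along_axis`:

```python
from typing import List, Tuple


def sort_rows_together(
    dm: List[List[float]], inds: List[List[int]]
) -> Tuple[List[List[float]], List[List[int]]]:
    """Sort each row of dm ascending and reorder inds with the same permutation."""
    sorted_dm, sorted_inds = [], []
    for dist_row, ind_row in zip(dm, inds):
        order = sorted(range(len(dist_row)), key=dist_row.__getitem__)
        sorted_dm.append([dist_row[p] for p in order])
        sorted_inds.append([ind_row[p] for p in order])
    return sorted_dm, sorted_inds


dm = [[0.4, 0.1, 0.3]]
inds = [[7, 2, 5]]
dm_sorted, inds_sorted = sort_rows_together(dm, inds)
```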
464