Passed
Push — master ( af43de...b2b3ff )
by Daniel
01:47
created

amd.compare   A

Complexity

Total Complexity 40

Size/Duplication

Total Lines 346
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
wmc 40
eloc 144
dl 0
loc 346
rs 9.2
c 0
b 0
f 0

8 Functions

Rating   Name   Duplication   Size   Complexity  
C SDD_EMD() 0 74 10
B AMD_pdist() 0 46 5
C PDD_cdist() 0 48 9
A set_verbose() 0 6 1
B PDD_pdist() 0 40 6
A EMD() 0 42 2
A emd() 0 8 1
B AMD_cdist() 0 48 6

How to fix   Complexity   

Complexity

Complex classes (or, as here, modules) like amd.compare often do many different things. To break such a unit down, identify a cohesive component within it. A common way to find one is to look for fields or methods that share the same prefix or suffix.

Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a sub-class, Extract Subclass is also a candidate, and is often faster.
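
As a rough illustration of the prefix heuristic, the sketch below groups the eight function names from the table above by the text before their first underscore. The helper names `group_by_prefix` and `cohesive_groups` are hypothetical and are not part of amd.compare.

```python
from collections import defaultdict

# Function names copied from the "8 Functions" table of this report.
FUNCTION_NAMES = [
    "SDD_EMD", "AMD_pdist", "PDD_cdist", "set_verbose",
    "PDD_pdist", "EMD", "emd", "AMD_cdist",
]


def group_by_prefix(names):
    """Group names by the text before the first underscore (hypothetical helper)."""
    groups = defaultdict(list)
    for name in names:
        prefix, sep, _rest = name.partition("_")
        # Names without an underscore form their own single-member group.
        groups[prefix if sep else name].append(name)
    return dict(groups)


def cohesive_groups(names):
    """Keep only prefixes shared by at least two names."""
    return {p: ns for p, ns in group_by_prefix(names).items() if len(ns) > 1}


print(cohesive_groups(FUNCTION_NAMES))
```

For this module the heuristic singles out the `AMD_*` and `PDD_*` pairs, which suggests those are the natural candidates if the file were ever split.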

1
"""Functions for comparing AMDs and PDDs of crystals.
2
"""
3
4
from typing import List, Tuple, Optional, Union
Unused Code: Unused Tuple imported from typing.
5
import warnings
6
7
import numpy as np
8
import scipy.spatial    # cdist, pdist, squareform
9
import scipy.optimize   # linear_sum_assignment
10
11
from ._network_simplex import network_simplex
12
from .utils import ETA
13
14
15
_VERBOSE = False
16
_VERBOSE_UPDATE_RATE = 100
17
18
def set_verbose(setting, update_rate=100):
19
    """Pass True/False to turn on/off an ETA where relevant."""
20
    global _VERBOSE
Coding Style: Usage of the global statement should be avoided.

Usage of global can make code hard to read and test; it is generally not recommended unless you are dealing with legacy code.
21
    global _VERBOSE_UPDATE_RATE
Coding Style: Usage of the global statement should be avoided.
22
    _VERBOSE = setting
23
    _VERBOSE_UPDATE_RATE = update_rate
24
25
set_verbose(False)
26
27
28
def EMD(
29
        pdd: np.ndarray,
30
        pdd_: np.ndarray,
31
        metric: Optional[str] = 'chebyshev',
32
        return_transport: Optional[bool] = False,
33
        **kwargs):
34
    r"""Earth mover's distance (EMD) between two PDDs, also known as
35
    the Wasserstein metric.
36
37
    Parameters
38
    ----------
39
    pdd : ndarray
40
        PDD of a crystal.
41
    pdd\_ : ndarray
42
        PDD of a crystal.
43
    metric : str or callable, optional
44
        EMD between PDDs requires defining a distance between rows of two PDDs.
45
        By default, Chebyshev/l-infinity distance is chosen as with AMDs.
46
        Can take any metric + ``kwargs`` accepted by
47
        ``scipy.spatial.distance.cdist``.
48
    return_transport : bool, optional
49
        Return a tuple (distance, transport_plan) with the optimal transport.
50
51
    Returns
52
    -------
53
    float
54
        Earth mover's distance between PDDs.
55
56
    Raises
57
    ------
58
    ValueError
59
        Raised if the two PDDs do not have the
60
        same number of columns (``k`` value).
61
    """
62
63
    dm = scipy.spatial.distance.cdist(pdd[:, 1:], pdd_[:, 1:], metric=metric, **kwargs)
64
    emd_dist, transport_plan = network_simplex(pdd[:, 0], pdd_[:, 0], dm)
65
66
    if return_transport:
67
        return emd_dist, transport_plan.reshape(dm.shape)
68
69
    return emd_dist
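
To make the docstring above concrete, here is a brute-force sketch for one special case: both PDDs have the same number of rows `n` and every weight equals `1/n`. In that case the optimal transport can be taken to be a permutation, so the EMD is the cheapest row matching divided by `n`. This is not how `EMD` computes its result (it uses `network_simplex`); the function `emd_uniform_bruteforce` is hypothetical and only practical for a handful of rows.

```python
from itertools import permutations


def chebyshev(u, v):
    """Chebyshev (l-infinity) distance between two equal-length rows."""
    return max(abs(a - b) for a, b in zip(u, v))


def emd_uniform_bruteforce(pdd, pdd_):
    """EMD of two PDDs with equal row counts and uniform weights (assumption).

    Column 0 holds the weights and is ignored here because all weights
    are assumed to equal 1/n; the remaining columns are the distances.
    """
    n = len(pdd)
    if len(pdd_) != n:
        raise ValueError("this sketch needs the same number of rows")
    rows = [row[1:] for row in pdd]
    rows_ = [row[1:] for row in pdd_]
    # Try every one-to-one matching of rows and keep the cheapest total cost.
    best = min(
        sum(chebyshev(rows[i], rows_[perm[i]]) for i in range(n))
        for perm in permutations(range(n))
    )
    return best / n


pdd = [[0.5, 1.0, 2.0], [0.5, 1.5, 2.5]]
pdd_ = [[0.5, 1.0, 2.2], [0.5, 1.4, 2.5]]
print(emd_uniform_bruteforce(pdd, pdd_))  # approximately 0.15
```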
70
71
72
def SDD_EMD(sdd, sdd_, return_transport: Optional[bool] = False):
73
    r"""Earth mover's distance (EMD) between two SDDs.
74
75
    Parameters
76
    ----------
77
    sdd : tuple of ndarrays
78
        SDD of a crystal.
79
    sdd\_ : tuple of ndarrays
80
        SDD of a crystal.
81
    return_transport : bool, optional
82
        Return a tuple (distance, transport_plan) with the optimal transport.
83
84
    Returns
85
    -------
86
    float
87
        Earth mover's distance between SDDs.
88
89
    Raises
90
    ------
91
    ValueError
92
        Raised if the two SDDs are not of the same order or do not have the
93
        same number of columns (``k`` value).
94
    """
95
96
    dists, dists_ = sdd[2], sdd_[2]
97
98
    # first order SDD, equivalent to PDD
99
    if dists.ndim == 2 and dists_.ndim == 2:
100
        dm = scipy.spatial.distance.cdist(dists, dists_, metric='chebyshev')
101
        emd_dist, transport_plan = network_simplex(sdd[0], sdd_[0], dm)
102
103
        if return_transport:
104
            return emd_dist, transport_plan.reshape(dm.shape)
105
106
        return emd_dist
107
108
    order = dists.shape[-1]
109
    n, m = len(sdd[0]), len(sdd_[0])
110
111
    dist_cdist = None
112
    if order == 2:
113
        dist_cdist = np.abs(sdd[1][:, None] - sdd_[1])
114
    else:
115
        dist, dist_ = sdd[1], sdd_[1]
116
117
        # take EMDs between finite PDDs in dist column
118
        weights = np.full((order, ), 1 / order)
119
        dist_cdist = np.empty((n, m), dtype=np.float64)
120
        for i in range(n):
121
            for j in range(m):
122
                finite_pdd_dm = scipy.spatial.distance.cdist(dist[i], dist_[j], metric='chebyshev')
123
                dists_emd, _ = network_simplex(weights, weights, finite_pdd_dm)
124
                dist_cdist[i, j] = dists_emd
125
126
        # flatten and compare by linf
127
        # flat_dist = dist.reshape((n, order * (order - 1)))
128
        # flat_dist_ = dist_.reshape((m, order * (order - 1)))
129
        # flat_dist = np.sort(flat_dist, axis=-1)
130
        # flat_dist_ = np.sort(flat_dist_, axis=-1)
131
        # dist_cdist = scipy.spatial.distance.cdist(flat_dist, flat_dist_, metric='chebyshev')
132
133
    dm = np.empty((n, m), dtype=np.float64)
134
    for i in range(n):
135
        for j in range(m):
136
            cost_matrix = scipy.spatial.distance.cdist(dists[i], dists_[j], metric='chebyshev')
137
            row_ind, col_ind = scipy.optimize.linear_sum_assignment(cost_matrix)
138
            dm[i, j] = max(np.amax(cost_matrix[row_ind, col_ind]), dist_cdist[i, j])
139
140
    emd_dist, transport_plan = network_simplex(sdd[0], sdd_[0], dm)
141
142
    if return_transport:
143
        return emd_dist, transport_plan.reshape(dm.shape)
144
145
    return emd_dist
146
147
148
def AMD_cdist(
149
        amds: Union[np.ndarray, List[np.ndarray]],
150
        amds_: Union[np.ndarray, List[np.ndarray]],
151
        metric: str = 'chebyshev',
152
        low_memory: bool = False,
153
        **kwargs
154
) -> np.ndarray:
155
    r"""Compare two sets of AMDs with each other, returning a distance matrix.
156
157
    Parameters
158
    ----------
159
    amds : array_like
160
        A list of AMDs.
161
    amds\_ : array_like
162
        A list of AMDs.
163
    metric : str or callable, optional
164
        Usually AMDs are compared with the Chebyshev/l-infinity distance.
165
        Can take any metric + kwargs accepted by ``scipy.spatial.distance.cdist``.
166
    low_memory : bool, optional
167
        Use a slower but more memory efficient method for
168
        large collections of AMDs (Chebyshev/l-inf distance only).
169
170
    Returns
171
    -------
172
    ndarray
173
        A distance matrix shape ``(len(amds), len(amds_))``.
174
        The :math:`ij` th entry is the distance between ``amds[i]``
175
        and ``amds_[j]`` given by the ``metric``.
176
    """
177
178
    amds, amds_ = np.asarray(amds), np.asarray(amds_)
179
180
    if len(amds.shape) == 1:
181
        amds = np.array([amds])
182
    if len(amds_.shape) == 1:
183
        amds_ = np.array([amds_])
184
185
    if low_memory:
186
        if metric != 'chebyshev':
187
            warnings.warn("Using only allowed metric 'chebyshev' for low_memory", UserWarning)
188
189
        dm = np.empty((len(amds), len(amds_)))
190
        for i, amd_vec in enumerate(amds):
191
            dm[i] = np.amax(np.abs(amds_ - amd_vec), axis=-1)
192
    else:
193
        dm = scipy.spatial.distance.cdist(amds, amds_, metric=metric, **kwargs)
194
195
    return dm
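
The `low_memory` branch above can be read as a row-by-row version of a single broadcast. The sketch below places the two side by side, assuming NumPy is available; both function names are hypothetical. The loop only ever holds one `(len(amds_), k)` temporary, whereas the broadcast builds a `(len(amds), len(amds_), k)` array before reducing it.

```python
import numpy as np


def chebyshev_cdist_low_memory(amds, amds_):
    """Row-by-row Chebyshev distance matrix, mirroring the loop above."""
    amds, amds_ = np.atleast_2d(amds), np.atleast_2d(amds_)
    dm = np.empty((len(amds), len(amds_)))
    for i, amd_vec in enumerate(amds):
        # One temporary of shape (len(amds_), k) per iteration.
        dm[i] = np.amax(np.abs(amds_ - amd_vec), axis=-1)
    return dm


def chebyshev_cdist_broadcast(amds, amds_):
    """Same result in one step, at the cost of an (n, m, k) temporary."""
    amds, amds_ = np.atleast_2d(amds), np.atleast_2d(amds_)
    return np.amax(np.abs(amds[:, None, :] - amds_[None, :, :]), axis=-1)
```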
196
197
198
def AMD_pdist(
199
        amds: Union[np.ndarray, List[np.ndarray]],
200
        metric: str = 'chebyshev',
201
        low_memory: bool = False,
202
        **kwargs
203
) -> np.ndarray:
204
    """Compare a set of AMDs pairwise, returning a condensed distance matrix.
205
206
    Parameters
207
    ----------
208
    amds : array_like
209
        An array/list of AMDs.
210
    metric : str or callable, optional
211
        Usually AMDs are compared with the Chebyshev/l-infinity distance.
212
        Can take any metric + kwargs accepted by ``scipy.spatial.distance.pdist``.
213
    low_memory : bool, optional
214
        Optionally use a slightly slower but more memory efficient method for
215
        large collections of AMDs (Chebyshev/l-inf distance only).
216
217
    Returns
218
    -------
219
    ndarray
220
        Returns a condensed distance matrix. Collapses a square distance
221
        matrix into a vector just keeping the upper half. Use
222
        ``scipy.spatial.distance.squareform`` to convert to a square distance matrix.
223
    """
224
225
    amds = np.asarray(amds)
226
227
    if len(amds.shape) == 1:
228
        amds = np.array([amds])
229
230
    if low_memory:
231
        m = len(amds)
232
        if metric != 'chebyshev':
233
            warnings.warn("Using only allowed metric 'chebyshev' for low_memory", UserWarning)
234
        cdm = np.empty((m * (m - 1)) // 2, dtype=np.double)
235
        ind = 0
236
        for i in range(m):
237
            ind_ = ind + m - i - 1
238
            cdm[ind:ind_] = np.amax(np.abs(amds[i+1:] - amds[i]), axis=-1)
239
            ind = ind_
240
    else:
241
        cdm = scipy.spatial.distance.pdist(amds, metric=metric, **kwargs)
242
243
    return cdm
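
The `ind`/`ind_` bookkeeping in the `low_memory` branch is easier to follow with the slice boundaries written out. The sketch below reproduces them for a given `m` and adds the closed-form position of a pair `(i, j)` with `i < j` in a condensed vector; both helper names are hypothetical.

```python
def condensed_row_slices(m):
    """Return the (start, stop) slice written for each row i, as in the loop above."""
    slices = []
    ind = 0
    for i in range(m):
        ind_ = ind + m - i - 1  # row i contributes m - i - 1 distances
        slices.append((ind, ind_))
        ind = ind_
    return slices


def condensed_index(i, j, m):
    """Position of the pair (i, j), i < j, in a condensed distance vector."""
    return m * i - i * (i + 1) // 2 + (j - i - 1)


print(condensed_row_slices(4))  # [(0, 3), (3, 5), (5, 6), (6, 6)]
```

The last row always gets an empty slice, and the final `stop` equals `m * (m - 1) // 2`, the length of `cdm`.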
244
245
246
def PDD_cdist(
247
        pdds: List[np.ndarray],
248
        pdds_: List[np.ndarray],
249
        metric: str = 'chebyshev',
250
        **kwargs
251
) -> np.ndarray:
252
    r"""Compare two sets of PDDs with each other, returning a distance matrix.
253
254
    Parameters
255
    ----------
256
    pdds : list of ndarrays
257
        A list of PDDs.
258
    pdds\_ : list of ndarrays
259
        A list of PDDs.
260
    metric : str or callable, optional
261
        Usually PDD rows are compared with the Chebyshev/l-infinity distance.
262
        Can take any metric + kwargs accepted by
263
        ``scipy.spatial.distance.cdist``.
264
265
    Returns
266
    -------
267
    ndarray
268
        Returns a distance matrix shape ``(len(pdds), len(pdds_))``.
269
        The :math:`ij` th entry is the distance between ``pdds[i]``
270
        and ``pdds_[j]`` given by Earth mover's distance.
271
    """
272
273
    if isinstance(pdds, np.ndarray):
274
        if len(pdds.shape) == 2:
275
            pdds = [pdds]
276
277
    if isinstance(pdds_, np.ndarray):
278
        if len(pdds_.shape) == 2:
279
            pdds_ = [pdds_]
280
281
    n, m = len(pdds), len(pdds_)
282
    dm = np.empty((n, m))
283
    if _VERBOSE:
284
        eta = ETA(n * m, update_rate=_VERBOSE_UPDATE_RATE)
285
286
    for i in range(n):
287
        pdd = pdds[i]
288
        for j in range(m):
289
            dm[i, j] = EMD(pdd, pdds_[j], metric=metric, **kwargs)
290
            if _VERBOSE:
291
                eta.update()
The variable eta does not seem to be defined in case _VERBOSE on line 283 is False. Are you sure this can never be the case?
292
293
    return dm
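
The warning about `eta` is a false alarm at runtime as long as `_VERBOSE` does not change between the two checks, but it can be silenced by binding `eta` unconditionally and testing it directly. A minimal sketch of that pattern follows; `ProgressStub` is a hypothetical stand-in for `ETA`, and `fill_matrix` is a simplified placeholder for the double loop, not the reviewed function.

```python
class ProgressStub:
    """Hypothetical stand-in for ``ETA``; it only counts updates."""

    def __init__(self, total, update_rate=100):
        self.total = total
        self.update_rate = update_rate
        self.count = 0

    def update(self):
        self.count += 1


def fill_matrix(n, m, compute, verbose=False, update_rate=100):
    """Fill an n-by-m matrix, reporting progress only when verbose is set."""
    # eta is always bound, so a static analyser can see it is defined.
    eta = ProgressStub(n * m, update_rate=update_rate) if verbose else None
    dm = [[0.0] * m for _ in range(n)]
    for i in range(n):
        for j in range(m):
            dm[i][j] = compute(i, j)
            if eta is not None:
                eta.update()
    return dm, eta
```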
294
295
296
def PDD_pdist(
297
        pdds: List[np.ndarray],
298
        metric: str = 'chebyshev',
299
        **kwargs
300
) -> np.ndarray:
301
    """Compare a set of PDDs pairwise, returning a condensed distance matrix.
302
303
    Parameters
304
    ----------
305
    pdds : list of ndarrays
306
        A list of PDDs.
307
    metric : str or callable, optional
308
        Usually PDD rows are compared with the Chebyshev/l-infinity distance.
309
        Can take any metric + kwargs accepted by ``scipy.spatial.distance.cdist``.
310
311
    Returns
312
    -------
313
    ndarray
314
        Returns a condensed distance matrix. Collapses a square
315
        distance matrix into a vector just keeping the upper half. Use
316
        ``scipy.spatial.distance.squareform`` to convert to a square
317
        distance matrix.
318
    """
319
320
    if isinstance(pdds, np.ndarray):
321
        if len(pdds.shape) == 2:
322
            pdds = [pdds]
323
324
    m = len(pdds)
325
    cdm = np.empty((m * (m - 1)) // 2, dtype=np.double)
326
    if _VERBOSE:
327
        eta = ETA((m * (m - 1)) // 2, update_rate=_VERBOSE_UPDATE_RATE)
328
    inds = ((i, j) for i in range(0, m - 1) for j in range(i + 1, m))
Comprehensibility Best Practice: The variable j does not seem to be defined.
Comprehensibility Best Practice: The variable i does not seem to be defined.
329
330
    for r, (i, j) in enumerate(inds):
331
        cdm[r] = EMD(pdds[i], pdds[j], metric=metric, **kwargs)
332
        if _VERBOSE:
333
            eta.update()
The variable eta does not seem to be defined in case _VERBOSE on line 326 is False. Are you sure this can never be the case?
334
335
    return cdm
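
The nested generator expression that triggered the two "does not seem to be defined" notes yields index pairs in the same order as `itertools.combinations`. Switching to it is one way to avoid the analyser confusion. The wrapper name `pair_indices` below is hypothetical.

```python
from itertools import combinations


def pair_indices(m):
    """Yield (i, j) with 0 <= i < j < m in row-major order."""
    return combinations(range(m), 2)


def pair_indices_original(m):
    """The generator expression used in the listing, for comparison."""
    return ((i, j) for i in range(0, m - 1) for j in range(i + 1, m))


print(list(pair_indices(4)))  # [(0, 1), (0, 2), (0, 3), (1, 2), (1, 3), (2, 3)]
```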
336
337
338
def emd(
339
        pdd: np.ndarray,
340
        pdd_: np.ndarray,
341
        metric: Optional[str] = 'chebyshev',
342
        return_transport: Optional[bool] = False,
343
        **kwargs):
344
    """Alias for amd.EMD()."""
345
    return EMD(pdd, pdd_, metric=metric, return_transport=return_transport, **kwargs)
346