Passed
Push to master (c9cb02...43c976)
by Daniel
01:44
created

amd.compare.SDD_EMD()   Rating: C

Complexity: Conditions 10
Size: Total Lines 66, Code Lines 31
Duplication: Lines 0, Ratio 0 %
Importance: Changes 0

Metric   Value
eloc     31
dl       0
loc      66
rs       5.9999
c        0
b        0
f        0
cc       10
nop      3

How to fix: Long Method, Complexity

Long Method

Small methods make your code easier to understand, in particular if combined with a good name. Besides, if your method is small, finding a good name is usually much easier.

For example, if you find yourself adding comments to a method's body, that is usually a good sign that the commented part should be extracted into a new method. The comment then makes a good starting point for naming the new method.

Commonly applied refactorings include Extract Method.

Complexity

Complex functions like amd.compare.SDD_EMD() often do many different things. To break such a function down, identify a cohesive component within it. A common approach is to look for variables or helper calls that share the same prefixes or suffixes.

Once you have determined which parts belong together, extract them. For a function like this one, that usually means Extract Method. For a class, the Extract Class refactoring applies, and if the component makes sense as a subclass, Extract Subclass is also a candidate and is often faster.
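In this module, the verbose/ETA progress reporting in PDD_cdist and PDD_pdist is one such cohesive component. It is created under one if verbose: check and updated under another, which is why the inspection flags eta as possibly undefined. The sketch below applies Extract Class by moving the reporting behind a small progress object with a no-op variant, so the loop can call update() unconditionally. NullProgress, CountingProgress, make_progress and pairwise are hypothetical names. CountingProgress only stands in for the package's ETA class, whose interface beyond update() is not shown here.

```python
import numpy as np


class NullProgress:
    """Progress tracker that does nothing (used when verbose is False)."""

    def update(self):
        pass


class CountingProgress:
    """Hypothetical stand-in for a real progress/ETA tracker: counts updates."""

    def __init__(self, total):
        self.total = total
        self.done = 0

    def update(self):
        self.done += 1


def make_progress(total, verbose):
    """Return a tracker that is always defined, whatever verbose is."""
    return CountingProgress(total) if verbose else NullProgress()


def pairwise(items, dist, verbose=False):
    """Condensed pairwise distances with progress reporting (sketch)."""
    m = len(items)
    cdm = np.empty(m * (m - 1) // 2, dtype=np.double)
    progress = make_progress(len(cdm), verbose)
    r = 0
    for i in range(m - 1):
        for j in range(i + 1, m):
            cdm[r] = dist(items[i], items[j])
            progress.update()  # safe: progress exists in both branches
            r += 1
    return cdm, progress
```

With this design, the decision about verbosity is made once, and static analysers no longer see a variable that exists on only one branch.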

"""Functions for comparing AMDs and PDDs of crystals.
"""

from typing import List, Optional, Union
import warnings

import numpy as np
import scipy.spatial    # cdist, pdist, squareform
import scipy.optimize   # linear_sum_assignment

from ._network_simplex import network_simplex
from .utils import ETA


def EMD(
        pdd: np.ndarray,
        pdd_: np.ndarray,
        metric: Optional[str] = 'chebyshev',
        return_transport: Optional[bool] = False,
        **kwargs):
    r"""Earth mover's distance (EMD) between two PDDs, also known as
    the Wasserstein metric.

    Parameters
    ----------
    pdd : ndarray
        PDD of a crystal.
    pdd\_ : ndarray
        PDD of a crystal.
    metric : str or callable, optional
        EMD between PDDs requires defining a distance between rows of two PDDs.
        By default, Chebyshev/l-infinity distance is chosen as with AMDs.
        Can take any metric + ``kwargs`` accepted by
        ``scipy.spatial.distance.cdist``.
    return_transport : bool, optional
        Return a tuple ``(distance, transport_plan)`` with the optimal transport.

    Returns
    -------
    float or tuple
        Earth mover's distance between PDDs. If ``return_transport`` is
        True, a tuple ``(distance, transport_plan)`` is returned instead.

    Raises
    ------
    ValueError
        Thrown if the two PDDs do not have the
        same number of columns (``k`` value).
    """

    dm = scipy.spatial.distance.cdist(pdd[:, 1:], pdd_[:, 1:], metric=metric, **kwargs)
    emd_dist, transport_plan = network_simplex(pdd[:, 0], pdd_[:, 0], dm)

    if return_transport:
        return emd_dist, transport_plan.reshape(dm.shape)

    return emd_dist


def AMD_cdist(
        amds: Union[np.ndarray, List[np.ndarray]],
        amds_: Union[np.ndarray, List[np.ndarray]],
        metric: str = 'chebyshev',
        low_memory: bool = False,
        **kwargs
) -> np.ndarray:
    r"""Compare two sets of AMDs with each other, returning a distance matrix.

    Parameters
    ----------
    amds : array_like
        A list of AMDs.
    amds\_ : array_like
        A list of AMDs.
    metric : str or callable, optional
        Usually AMDs are compared with the Chebyshev/l-infinity distance.
        Can take any metric + kwargs accepted by ``scipy.spatial.distance.cdist``.
    low_memory : bool, optional
        Use a slower but more memory efficient method for
        large collections of AMDs (Chebyshev/l-inf distance only).

    Returns
    -------
    ndarray
        A distance matrix of shape ``(len(amds), len(amds_))``.
        The :math:`ij` th entry is the distance between ``amds[i]``
        and ``amds_[j]`` given by the ``metric``.
    """

    amds, amds_ = np.asarray(amds), np.asarray(amds_)

    if len(amds.shape) == 1:
        amds = np.array([amds])
    if len(amds_.shape) == 1:
        amds_ = np.array([amds_])

    if low_memory:
        if metric != 'chebyshev':
            warnings.warn("Using only allowed metric 'chebyshev' for low_memory", UserWarning)

        dm = np.empty((len(amds), len(amds_)))
        for i, amd_vec in enumerate(amds):
            dm[i] = np.amax(np.abs(amds_ - amd_vec), axis=-1)
    else:
        dm = scipy.spatial.distance.cdist(amds, amds_, metric=metric, **kwargs)

    return dm


def AMD_pdist(
        amds: Union[np.ndarray, List[np.ndarray]],
        metric: str = 'chebyshev',
        low_memory: bool = False,
        **kwargs
) -> np.ndarray:
    """Compare a set of AMDs pairwise, returning a condensed distance matrix.

    Parameters
    ----------
    amds : array_like
        An array/list of AMDs.
    metric : str or callable, optional
        Usually AMDs are compared with the Chebyshev/l-infinity distance.
        Can take any metric + kwargs accepted by ``scipy.spatial.distance.pdist``.
    low_memory : bool, optional
        Optionally use a slightly slower but more memory efficient method for
        large collections of AMDs (Chebyshev/l-inf distance only).

    Returns
    -------
    ndarray
        Returns a condensed distance matrix. Collapses a square distance
        matrix into a vector, keeping only the upper triangle (without the
        diagonal). Use ``scipy.spatial.distance.squareform`` to convert to a
        square distance matrix.
    """

    amds = np.asarray(amds)

    if len(amds.shape) == 1:
        amds = np.array([amds])

    if low_memory:
        m = len(amds)
        if metric != 'chebyshev':
            warnings.warn("Using only allowed metric 'chebyshev' for low_memory", UserWarning)
        cdm = np.empty((m * (m - 1)) // 2, dtype=np.double)
        ind = 0
        for i in range(m):
            ind_ = ind + m - i - 1
            cdm[ind:ind_] = np.amax(np.abs(amds[i+1:] - amds[i]), axis=-1)
            ind = ind_
    else:
        cdm = scipy.spatial.distance.pdist(amds, metric=metric, **kwargs)

    return cdm


def PDD_cdist(
        pdds: List[np.ndarray],
        pdds_: List[np.ndarray],
        metric: str = 'chebyshev',
        verbose=False,
        **kwargs
) -> np.ndarray:
    r"""Compare two sets of PDDs with each other, returning a distance matrix.

    Parameters
    ----------
    pdds : list of ndarrays
        A list of PDDs.
    pdds\_ : list of ndarrays
        A list of PDDs.
    metric : str or callable, optional
        Usually PDD rows are compared with the Chebyshev/l-infinity distance.
        Can take any metric + kwargs accepted by
        ``scipy.spatial.distance.cdist``.

    Returns
    -------
    ndarray
        Returns a distance matrix of shape ``(len(pdds), len(pdds_))``.
        The :math:`ij` th entry is the distance between ``pdds[i]``
        and ``pdds_[j]`` given by Earth mover's distance.
    """

    if isinstance(pdds, np.ndarray):
        if len(pdds.shape) == 2:
            pdds = [pdds]

    if isinstance(pdds_, np.ndarray):
        if len(pdds_.shape) == 2:
            pdds_ = [pdds_]

    n, m = len(pdds), len(pdds_)
    dm = np.empty((n, m))
    if verbose:
        update_rate = (n * m) // 10000
        eta = ETA(n * m, update_rate=update_rate)

    for i in range(n):
        pdd = pdds[i]
        for j in range(m):
            dm[i, j] = EMD(pdd, pdds_[j], metric=metric, **kwargs)
            if verbose:
                eta.update()
Issue: The variable eta does not seem to be defined if verbose is False where eta is created (the first if verbose: check). Are you sure this can never be the case? Note that eta is both created and used only under if verbose: guards, and verbose is not reassigned in between, so this is most likely a false positive.

    return dm


def PDD_pdist(
        pdds: List[np.ndarray],
        metric: str = 'chebyshev',
        verbose=False,
        **kwargs
) -> np.ndarray:
    """Compare a set of PDDs pairwise, returning a condensed distance matrix.

    Parameters
    ----------
    pdds : list of ndarrays
        A list of PDDs.
    metric : str or callable, optional
        Usually PDD rows are compared with the Chebyshev/l-infinity distance.
        Can take any metric + kwargs accepted by ``scipy.spatial.distance.cdist``.

    Returns
    -------
    ndarray
        Returns a condensed distance matrix. Collapses a square
        distance matrix into a vector, keeping only the upper triangle
        (without the diagonal). Use ``scipy.spatial.distance.squareform``
        to convert to a square distance matrix.
    """

    if isinstance(pdds, np.ndarray):
        if len(pdds.shape) == 2:
            pdds = [pdds]

    m = len(pdds)
    cdm_len = (m * (m - 1)) // 2
    cdm = np.empty(cdm_len, dtype=np.double)
    if verbose:
        update_rate = cdm_len // 10000
        eta = ETA(cdm_len, update_rate=update_rate)
    inds = ((i, j) for i in range(0, m - 1) for j in range(i + 1, m))
Issue (Comprehensibility Best Practice): The variable j does not seem to be defined.
Issue (Comprehensibility Best Practice): The variable i does not seem to be defined.
Both i and j are bound by the for clauses of the generator expression, so these are most likely false positives.

    for r, (i, j) in enumerate(inds):
        cdm[r] = EMD(pdds[i], pdds[j], metric=metric, **kwargs)
        if verbose:
            eta.update()
Issue: The variable eta does not seem to be defined if verbose is False where eta is created (the first if verbose: check). Are you sure this can never be the case? Note that eta is both created and used only under if verbose: guards, and verbose is not reassigned in between, so this is most likely a false positive.

    return cdm


def emd(
        pdd: np.ndarray,
        pdd_: np.ndarray,
        metric: Optional[str] = 'chebyshev',
        return_transport: Optional[bool] = False,
        **kwargs):
    """Alias for amd.EMD()."""
    return EMD(pdd, pdd_, metric=metric, return_transport=return_transport, **kwargs)
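The low_memory branches of AMD_cdist and AMD_pdist compute Chebyshev distances one row at a time with NumPy. The sketch below copies those two loops into standalone functions so they can be checked against a brute-force double loop. It also rebuilds the square matrix from the condensed vector, which is what scipy.spatial.distance.squareform does for a condensed input. The function names are hypothetical. As in the module, each AMD is assumed to be a row of equal length k.

```python
import numpy as np


def chebyshev_cdist_lowmem(amds, amds_):
    """Row-by-row Chebyshev distance matrix, as in AMD_cdist(low_memory=True)."""
    amds, amds_ = np.asarray(amds, dtype=float), np.asarray(amds_, dtype=float)
    dm = np.empty((len(amds), len(amds_)))
    for i, amd_vec in enumerate(amds):
        # Compare one AMD against all of amds_ at once.
        dm[i] = np.amax(np.abs(amds_ - amd_vec), axis=-1)
    return dm


def chebyshev_pdist_lowmem(amds):
    """Condensed Chebyshev distances, as in AMD_pdist(low_memory=True)."""
    amds = np.asarray(amds, dtype=float)
    m = len(amds)
    cdm = np.empty(m * (m - 1) // 2, dtype=np.double)
    ind = 0
    for i in range(m):
        # Row i of the upper triangle has m - i - 1 entries (j = i+1, ..., m-1).
        ind_ = ind + m - i - 1
        cdm[ind:ind_] = np.amax(np.abs(amds[i + 1:] - amds[i]), axis=-1)
        ind = ind_
    return cdm


def condensed_to_square(cdm, m):
    """Rebuild the symmetric square matrix from a condensed vector."""
    sq = np.zeros((m, m))
    iu = np.triu_indices(m, k=1)  # row-major order, matching the condensed layout
    sq[iu] = cdm
    return sq + sq.T
```

The last iteration of the pdist loop (i = m - 1) slices an empty block of rows. This is harmless: np.amax reduces over the non-empty last axis and returns an empty result.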
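PDD_pdist fills cdm by enumerating the pairs (i, j) with i < j in row-major order, so each position r in the condensed vector corresponds to exactly one pair. The closed-form index below is a standard identity for that ordering, and scipy.spatial.distance.pdist uses the same ordering. It lets you look up one pair's distance without building the square matrix. The helper name condensed_index is hypothetical.

```python
def condensed_index(i, j, m):
    """Position of pair (i, j) in a condensed distance vector of m items.

    Hypothetical helper; assumes the row-major i < j ordering used by
    PDD_pdist's generator (and by scipy.spatial.distance.pdist).
    """
    if i == j:
        raise ValueError("diagonal entries are not stored in condensed form")
    if i > j:
        i, j = j, i  # the matrix is symmetric
    # Rows 0..i-1 contribute (m-1) + (m-2) + ... + (m-i) = m*i - i*(i+1)/2 entries.
    return m * i - i * (i + 1) // 2 + (j - i - 1)


# Cross-check against the enumeration PDD_pdist uses.
m = 5
inds = ((i, j) for i in range(0, m - 1) for j in range(i + 1, m))
for r, (i, j) in enumerate(inds):
    assert condensed_index(i, j, m) == r
```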
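EMD treats the first column of each PDD as row weights and the remaining columns as points. The cost between rows is given by metric, which defaults to Chebyshev. In one tiny special case, the answer can be checked by hand: two PDDs with the same number of rows n and uniform weights 1/n. There, some optimal transport plan is a permutation matrix scaled by 1/n, so EMD reduces to the cheapest one-to-one matching of rows. The brute-force sketch below uses that fact to illustrate the quantity network_simplex computes in general. It is exponential in n and is not how the package computes EMD. The function name is hypothetical.

```python
from itertools import permutations

import numpy as np


def emd_uniform_bruteforce(pdd, pdd_):
    """EMD between two PDDs with n rows each and uniform weights 1/n (sketch).

    Column 0 holds the weights; the remaining columns are compared with the
    Chebyshev (l-infinity) distance, the default metric in EMD().
    """
    pdd, pdd_ = np.asarray(pdd, dtype=float), np.asarray(pdd_, dtype=float)
    n = len(pdd)
    if len(pdd_) != n:
        raise ValueError("sketch assumes both PDDs have the same number of rows")
    if not (np.allclose(pdd[:, 0], 1 / n) and np.allclose(pdd_[:, 0], 1 / n)):
        raise ValueError("sketch assumes uniform weights 1/n")
    # Chebyshev cost between every row of pdd and every row of pdd_.
    cost = np.amax(np.abs(pdd[:, None, 1:] - pdd_[None, :, 1:]), axis=-1)
    # With uniform weights, some optimal plan is a permutation (scaled by 1/n).
    best = min(cost[np.arange(n), list(perm)].sum() for perm in permutations(range(n)))
    return best / n
```

For p = [[0.5, 1.0, 2.0], [0.5, 3.0, 4.0]] and q = [[0.5, 3.5, 4.0], [0.5, 1.0, 2.5]], the identity matching costs 2.5 + 2.0 and the swapped matching costs 0.5 + 0.5. The sketch therefore returns 1.0 / 2 = 0.5.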