Passed
Push — master ( 6f99a2...0038b7 )
by Daniel
09:15 queued 06:11
created

amd.compare._unwrap_periodicset_list()   A

Complexity

Conditions 3

Size

Total Lines 11
Code Lines 6

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 6
dl 0
loc 11
rs 10
c 0
b 0
f 0
cc 3
nop 2
1
"""Functions for comparing AMDs and PDDs of crystals.
2
"""
3
4
from typing import List, Optional, Union
5
import warnings
6
from itertools import combinations
7
from functools import partial
8
import os
9
10
import numpy as np
11
from scipy.spatial.distance import cdist, pdist, squareform
12
from joblib import Parallel, delayed
13
import pandas as pd
14
15
from ._emd import network_simplex
16
from .periodicset import PeriodicSet
17
from .calculate import AMD, PDD
18
from .io import CifReader, CSDReader
19
20
21
def compare(crystals, crystals_=None, by='AMD', k=100, **kwargs):
    r"""Given one or two sets of periodic set(s), refcode(s) or cif(s), compare them
    returning a DataFrame of the distance matrix. Default is to compare by AMD
    with k=100. Accepts most keyword arguments accepted by the CifReader, CSDReader
    and compare functions, for a full list see the documentation Quick Start page.
    Note that using refcodes requires csd-python-api.

    Parameters
    ----------
    crystals : array or list of arrays
        One or a collection of paths, refcodes, file objects or :class:`.periodicset.PeriodicSet` s.
    crystals\_ : array or list of arrays, optional
        One or a collection of paths, refcodes, file objects or :class:`.periodicset.PeriodicSet` s.
        If omitted, ``crystals`` is compared pairwise with itself.
    by : str, default 'AMD'
        Invariant to compare by, either 'AMD' or 'PDD' (case-insensitive).
    k : int, default 100
        k value to use for the invariants (length of AMD, or number of columns in PDD).
    **kwargs
        Routed by name: ``collapse`` and ``collapse_tol`` go to :func:`PDD`;
        ``reader``, ``families``, ``remove_hydrogens``, ``disorder``,
        ``heaviest_component`` and ``show_warnings`` go to the readers;
        everything else (e.g. ``metric``, ``n_jobs``, ``verbose``,
        ``low_memory``) is passed to the distance function.

    Returns
    -------
    df : pandas.DataFrame
        DataFrame of the distance matrix for the given crystals compared by the chosen invariant.

    Raises
    ------
    ValueError
        If by is not 'AMD' or 'PDD', if either set given has no valid crystals
        to compare, or if crystals or crystals\_ are an invalid type.

    Examples
    --------
    Compare everything in a .cif (default, AMD with k=100)::

        df = amd.compare('data.cif')

    Compare everything in one cif with all crystals in all cifs in a directory (PDD, k=50)::

        df = amd.compare('data.cif', 'dir/to/cifs', by='PDD', k=50)

    **Examples (csd-python-api only)**

    Compare two crystals by CSD refcode (PDD, k=50)::

        df = amd.compare('DEBXIT01', 'DEBXIT02', by='PDD', k=50)

    Compare everything in a refcode family (AMD, k=100)::

        df = amd.compare('DEBXIT', families=True)
    """

    by = by.upper()
    if by not in ('AMD', 'PDD'):
        raise ValueError(f"parameter 'by' in compare accepts 'AMD' or 'PDD', was passed {by}")

    reader_kwargs = {
        'reader': 'ase',
        'families': False,
        'remove_hydrogens': False,
        'disorder': 'skip',
        'heaviest_component': False,
        'show_warnings': True,
    }

    calc_kwargs = {
        'collapse': True,
        'collapse_tol': 1e-4,
    }

    compare_kwargs = {
        'metric': 'chebyshev',
        'n_jobs': None,
        'verbose': 0,
        'low_memory': False,
    }

    # Route each user keyword to the stage that understands it; whatever is
    # left over is forwarded to the distance function. The key intersection
    # is materialised as a set first, so popping from kwargs here is safe.
    for key in kwargs.keys() & calc_kwargs.keys():
        calc_kwargs[key] = kwargs.pop(key)
    for key in kwargs.keys() & reader_kwargs.keys():
        reader_kwargs[key] = kwargs.pop(key)
    compare_kwargs.update(kwargs)

    crystals = _unwrap_periodicset_list(crystals, **reader_kwargs)
    if not crystals:
        raise ValueError('No valid crystals to compare in first set passed.')

    if by == 'AMD':
        # The AMD distance functions take no joblib options.
        compare_kwargs.pop('n_jobs', None)
        compare_kwargs.pop('verbose', None)
        pdist_func, cdist_func = AMD_pdist, AMD_cdist

        def _invariant(periodic_set):
            return AMD(periodic_set, k)
    else:
        # The PDD distance functions have no low_memory option.
        compare_kwargs.pop('low_memory', None)
        pdist_func, cdist_func = PDD_pdist, PDD_cdist

        def _invariant(periodic_set):
            # Both sets must be computed with the same options (including the
            # user's collapse settings) so their invariants are comparable.
            return PDD(periodic_set, k, lexsort=False, **calc_kwargs)

    invs = [_invariant(s) for s in crystals]
    names = [s.name for s in crystals]

    if crystals_ is None:
        names_ = names
        dm = squareform(pdist_func(invs, **compare_kwargs))
    else:
        crystals_ = _unwrap_periodicset_list(crystals_, **reader_kwargs)
        if not crystals_:
            raise ValueError('No valid crystals to compare in second set passed.')
        names_ = [s.name for s in crystals_]
        invs_ = [_invariant(s) for s in crystals_]
        dm = cdist_func(invs, invs_, **compare_kwargs)

    return pd.DataFrame(dm, index=names, columns=names_)
140
141
142
def EMD(
        pdd: np.ndarray,
        pdd_: np.ndarray,
        metric: Optional[str] = 'chebyshev',
        return_transport: Optional[bool] = False,
        **kwargs):
    r"""Earth mover's distance (EMD) between two PDDs, also known as
    the Wasserstein metric.

    The first column of each PDD is handed to the network simplex solver as
    the weights; the remaining columns are the row vectors whose pairwise
    ``metric`` distances form the cost matrix.

    Parameters
    ----------
    pdd : numpy.ndarray
        PDD of a crystal.
    pdd\_ : numpy.ndarray
        PDD of a crystal.
    metric : str or callable, default 'chebyshev'
        EMD between PDDs requires defining a distance between PDD rows.
        By default, Chebyshev (L-infinity) distance is chosen as with AMDs.
        Accepts any metric accepted by :func:`scipy.spatial.distance.cdist`.
    return_transport : bool, default False
        Return a tuple ``(distance, transport_plan)`` with the optimal transport.
    **kwargs
        Extra keyword arguments forwarded to :func:`scipy.spatial.distance.cdist`.

    Returns
    -------
    emd : float
        Earth mover's distance between two PDDs (or ``(emd, transport_plan)``
        when ``return_transport`` is True).

    Raises
    ------
    ValueError
        Thrown if ``pdd`` and ``pdd_`` do not have the same number of
        columns (``k`` value).
    """
    cost_matrix = cdist(pdd[:, 1:], pdd_[:, 1:], metric=metric, **kwargs)
    distance, plan = network_simplex(pdd[:, 0], pdd_[:, 0], cost_matrix)
    if not return_transport:
        return distance
    return distance, plan
183
184
185
def AMD_cdist(
        amds: Union[np.ndarray, List[np.ndarray]],
        amds_: Union[np.ndarray, List[np.ndarray]],
        metric: str = 'chebyshev',
        low_memory: bool = False,
        **kwargs
) -> np.ndarray:
    r"""Compare two sets of AMDs with each other, returning a distance matrix.
    This function is essentially identical to :func:`scipy.spatial.distance.cdist`
    with the default metric ``chebyshev``.

    Parameters
    ----------
    amds : array_like
        A list of AMDs (a single 1-D AMD is treated as a list of one).
    amds\_ : array_like
        A list of AMDs (a single 1-D AMD is treated as a list of one).
    metric : str or callable, default 'chebyshev'
        Usually AMDs are compared with the Chebyshev (L-infinity) distance.
        Can take any metric accepted by :func:`scipy.spatial.distance.cdist`.
    low_memory : bool, default False
        Use a slower but more memory efficient method for
        large collections of AMDs (Chebyshev metric only; any other
        ``metric`` triggers a warning and ``kwargs`` are ignored).
    **kwargs
        Extra keyword arguments forwarded to :func:`scipy.spatial.distance.cdist`
        when ``low_memory`` is False.

    Returns
    -------
    dm : numpy.ndarray
        A distance matrix shape ``(len(amds), len(amds_))``.
        ``dm[i, j]`` is the distance (given by ``metric``)
        between ``amds[i]`` and ``amds_[j]``.
    """

    amds, amds_ = np.asarray(amds), np.asarray(amds_)

    # A single AMD (1-D) is promoted to a collection containing one AMD.
    if amds.ndim == 1:
        amds = np.array([amds])
    if amds_.ndim == 1:
        amds_ = np.array([amds_])

    if not low_memory:
        return cdist(amds, amds_, metric=metric, **kwargs)

    if metric != 'chebyshev':
        warnings.warn("Using only allowed metric 'chebyshev' for low_memory", UserWarning)

    # Fill one row at a time so only a (len(amds_), k) temporary is alive.
    dm = np.empty((len(amds), len(amds_)))
    for i, amd_vec in enumerate(amds):
        dm[i] = np.amax(np.abs(amds_ - amd_vec), axis=-1)
    return dm
235
236
237
def AMD_pdist(
        amds: Union[np.ndarray, List[np.ndarray]],
        metric: str = 'chebyshev',
        low_memory: bool = False,
        **kwargs
) -> np.ndarray:
    """Compare a set of AMDs pairwise, returning a condensed distance matrix.
    This function is essentially identical to :func:`scipy.spatial.distance.pdist`
    with the default metric ``chebyshev``.

    Parameters
    ----------
    amds : array_like
        An array/list of AMDs (a single 1-D AMD is treated as a list of one).
    metric : str or callable, default 'chebyshev'
        Usually AMDs are compared with the Chebyshev (L-infinity) distance.
        Can take any metric accepted by :func:`scipy.spatial.distance.pdist`.
    low_memory : bool, default False
        Optionally use a slightly slower but more memory efficient method for
        large collections of AMDs (Chebyshev metric only; any other ``metric``
        triggers a warning and ``kwargs`` are ignored).
    **kwargs
        Extra keyword arguments forwarded to :func:`scipy.spatial.distance.pdist`
        when ``low_memory`` is False.

    Returns
    -------
    numpy.ndarray
        Returns a condensed distance matrix. Collapses a square distance
        matrix into a vector, just keeping the upper half. See
        :func:`scipy.spatial.distance.squareform` to convert to a square
        distance matrix or for more on condensed distance matrices.
    """

    amds = np.asarray(amds)
    if amds.ndim == 1:
        amds = np.array([amds])

    if not low_memory:
        return pdist(amds, metric=metric, **kwargs)

    n_amds = len(amds)
    if metric != 'chebyshev':
        warnings.warn("Using only allowed metric 'chebyshev' for low_memory", UserWarning)

    condensed = np.empty((n_amds * (n_amds - 1)) // 2, dtype=np.double)
    start = 0
    # Row i contributes its distances to every later AMD, in condensed order.
    for row, amd_vec in enumerate(amds):
        block = np.amax(np.abs(amds[row + 1:] - amd_vec), axis=-1)
        stop = start + block.shape[0]
        condensed[start:stop] = block
        start = stop
    return condensed
286
287
288
def PDD_cdist(
        pdds: List[np.ndarray],
        pdds_: List[np.ndarray],
        metric: str = 'chebyshev',
        n_jobs=None,
        verbose=0,
        **kwargs
) -> np.ndarray:
    r"""Compare two sets of PDDs with each other, returning a distance matrix.

    Parameters
    ----------
    pdds : List[numpy.ndarray]
        A list of PDDs (a single 2-D PDD array is treated as a list of one).
    pdds\_ : List[numpy.ndarray]
        A list of PDDs (a single 2-D PDD array is treated as a list of one).
    metric : str or callable, default 'chebyshev'
        Usually PDD rows are compared with the Chebyshev/l-infinity distance.
        Can take any metric accepted by :func:`scipy.spatial.distance.cdist`.
    n_jobs : int, default None
        Maximum number of concurrent jobs for parallel processing with joblib.
        Set to -1 to use the maximum possible. Note that for small inputs (< 100),
        using parallel processing may be slower than the default n_jobs=None.
    verbose : int, default 0
        The verbosity level. Higher = more verbose, see joblib.Parallel.
    **kwargs
        Extra keyword arguments forwarded to :func:`EMD`
        (``return_transport`` is ignored).

    Returns
    -------
    numpy.ndarray
        Returns a distance matrix shape ``(len(pdds), len(pdds_))``.
        The :math:`ij` th entry is the distance between ``pdds[i]``
        and ``pdds_[j]`` given by Earth mover's distance.
    """

    # A single PDD passed as a 2-D array is wrapped into a one-element list.
    if isinstance(pdds, np.ndarray) and pdds.ndim == 2:
        pdds = [pdds]
    if isinstance(pdds_, np.ndarray) and pdds_.ndim == 2:
        pdds_ = [pdds_]

    # Only the distance is wanted here, never the transport plan.
    kwargs.pop('return_transport', None)
    emd_func = partial(EMD, metric=metric, **kwargs)

    # i must be the OUTER loop: joblib returns results in submission order,
    # and the C-order reshape below expects row-major (i, j) ordering.
    dists = Parallel(backend='multiprocessing', n_jobs=n_jobs, verbose=verbose)(
        delayed(emd_func)(pdd, pdd_) for pdd in pdds for pdd_ in pdds_
    )
    return np.array(dists).reshape((len(pdds), len(pdds_)))
339
340
341
def PDD_pdist(
        pdds: List[np.ndarray],
        metric: str = 'chebyshev',
        n_jobs=None,
        verbose=0,
        **kwargs
) -> np.ndarray:
    """Compare a set of PDDs pairwise, returning a condensed distance matrix.

    Parameters
    ----------
    pdds : List[numpy.ndarray]
        A list of PDDs.
    metric : str or callable, default 'chebyshev'
        Usually PDD rows are compared with the Chebyshev/l-infinity distance.
        Can take any metric accepted by :func:`scipy.spatial.distance.pdist`.
    n_jobs : int, default None
        Maximum number of concurrent jobs for parallel processing with joblib.
        Set to -1 to use the maximum possible. Note that for small inputs (< 100),
        using parallel processing may be slower than the default n_jobs=None.
    verbose : int, default 0
        The verbosity level. Higher = more verbose, see joblib.Parallel for more.
    **kwargs
        Extra keyword arguments forwarded to :func:`EMD`
        (``return_transport`` is ignored).

    Returns
    -------
    numpy.ndarray
        Returns a condensed distance matrix. Collapses a square
        distance matrix into a vector just keeping the upper half. See
        :func:`scipy.spatial.distance.squareform` to convert to a square
        distance matrix or for more on condensed distance matrices.
    """

    # Pairwise comparison only needs distances, not transport plans.
    kwargs.pop('return_transport', None)
    pair_emd = partial(EMD, metric=metric, **kwargs)

    # combinations() yields pairs (i < j) in the same order as the upper
    # triangle of a condensed distance matrix.
    tasks = (delayed(pair_emd)(left, right) for left, right in combinations(pdds, 2))
    distances = Parallel(n_jobs=n_jobs, verbose=verbose)(tasks)
    return np.array(distances)
381
382
def emd(
        pdd: np.ndarray,
        pdd_: np.ndarray,
        metric: Optional[str] = 'chebyshev',
        return_transport: Optional[bool] = False,
        **kwargs):
    """Alias of :func:`EMD`; every argument is forwarded unchanged."""
    kwargs.update(metric=metric, return_transport=return_transport)
    return EMD(pdd, pdd_, **kwargs)
390
391
392
def _unwrap_periodicset_list(psets_or_str, **reader_kwargs):
    """Normalise input to :func:`compare` into a list of PeriodicSets.

    A single PeriodicSet becomes a one-element list. A list has each of its
    items expanded by :func:`_extract_periodicsets` and the results
    flattened. Anything else (path, refcode, file object) is expanded
    directly. ``reader_kwargs`` are forwarded to the reader for each item.
    """
    if isinstance(psets_or_str, PeriodicSet):
        return [psets_or_str]

    if not isinstance(psets_or_str, list):
        return _extract_periodicsets(psets_or_str, **reader_kwargs)

    flattened = []
    for item in psets_or_str:
        flattened.extend(_extract_periodicsets(item, **reader_kwargs))
    return flattened
403
404
405
def _extract_periodicsets(item, **reader_kwargs):
    """Turn one input item into a list of PeriodicSets.

    ``item`` may be a PeriodicSet (returned in a one-element list), a string
    naming neither an existing file nor an existing directory (treated as a
    CSD refcode and read with ``CSDReader``), or anything else such as a
    path or file object (read with ``CifReader``).
    """
    if isinstance(item, PeriodicSet):
        return [item]

    looks_like_refcode = isinstance(item, str) and not (
        os.path.isfile(item) or os.path.isdir(item)
    )
    if looks_like_refcode:
        # The 'reader' option is not forwarded to CSDReader.
        reader_kwargs.pop('reader', None)
        return list(CSDReader(item, **reader_kwargs))

    # The 'families' and 'refcodes' options are not forwarded to CifReader.
    for csd_option in ('families', 'refcodes'):
        reader_kwargs.pop(csd_option, None)
    return list(CifReader(item, **reader_kwargs))
417