Test Failed
Push — master ( 0038b7...1ed659 )
by Daniel
17:22 queued 14:23
created

amd.compare.emd()   A

Complexity

Conditions 1

Size

Total Lines 8
Code Lines 7

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 7
dl 0
loc 8
rs 10
c 0
b 0
f 0
cc 1
nop 5
1
"""Functions for comparing AMDs and PDDs of crystals.
2
"""
3
4
import warnings
5
from typing import List, Optional, Union
6
from functools import partial
7
from itertools import combinations
8
import os
9
10
import numpy as np
11
import pandas as pd
12
from scipy.spatial.distance import cdist, pdist, squareform
13
from joblib import Parallel, delayed
14
from progressbar import ProgressBar
15
16
from .io import CifReader, CSDReader
17
from .calculate import AMD, PDD
18
from ._emd import network_simplex
19
from .periodicset import PeriodicSet
20
21
22
def compare(
        crystals,
        crystals_=None,
        by='AMD',
        k=100,
        **kwargs
) -> pd.DataFrame:
    r"""Given one or two sets of periodic set(s), refcode(s) or cif(s), compare them
    returning a DataFrame of the distance matrix. Default is to compare by AMD
    with k=100. Accepts most keyword arguments accepted by the CifReader, CSDReader
    and compare functions, for a full list see the documentation Quick Start page.
    Note that using refcodes requires csd-python-api.

    Parameters
    ----------
    crystals : array or list of arrays
        One or a collection of paths, refcodes, file objects or :class:`.periodicset.PeriodicSet` s.
    crystals\_ : array or list of arrays, optional
        One or a collection of paths, refcodes, file objects or :class:`.periodicset.PeriodicSet` s.
        If omitted, ``crystals`` is compared pairwise with itself.
    by : str, default 'AMD'
        Invariant to compare by, either 'AMD' or 'PDD' (case-insensitive).
    k : int, default 100
        k value to use for the invariants (length of AMD, or number of columns in PDD).
    **kwargs
        Options for the readers (``reader``, ``families``, ``remove_hydrogens``,
        ``disorder``, ``heaviest_component``, ``molecular_centres``, ``show_warnings``),
        the PDD calculation (``collapse``, ``collapse_tol``, ``lexsort``) and the
        comparison (``metric``, ``n_jobs``, ``verbose``, ``low_memory``).
        Unrecognised keys are ignored.

    Returns
    -------
    df : pandas.DataFrame
        DataFrame of the distance matrix for the given crystals compared by the
        chosen invariant. The index holds the names of ``crystals``. The columns
        hold the names of ``crystals_``, or of ``crystals`` when ``crystals_`` is
        not given.

    Raises
    ------
    ValueError
        If by is not 'AMD' or 'PDD', if either set given have no valid crystals
        to compare, or if crystals or crystals\_ are an invalid type.

    Examples
    --------
    Compare everything in a .cif (default, AMD with k=100)::

        df = amd.compare('data.cif')

    Compare everything in one cif with all crystals in all cifs in a directory (PDD, k=50)::

        df = amd.compare('data.cif', 'dir/to/cifs', by='PDD', k=50)

    **Examples (csd-python-api only)**

    Compare two crystals by CSD refcode (PDD, k=50)::

        df = amd.compare('DEBXIT01', 'DEBXIT02', by='PDD', k=50)

    Compare everything in a refcode family (AMD, k=100)::

        df = amd.compare('DEBXIT', families=True)
    """

    by = by.upper()
    if by not in ('AMD', 'PDD'):
        raise ValueError(f"parameter 'by' in compare accepts 'AMD' or 'PDD', was passed {by}")

    reader_kwargs = {
        'reader': 'ase',
        'families': False,
        'remove_hydrogens': False,
        'disorder': 'skip',
        'heaviest_component': False,
        'molecular_centres': False,
        'show_warnings': True,
    }

    calc_kwargs = {
        'collapse': True,
        'collapse_tol': 1e-4,
        'lexsort': False,
    }

    compare_kwargs = {
        'metric': 'chebyshev',
        'n_jobs': None,
        'verbose': 0,
        'low_memory': False,
    }

    # Override each group's defaults with any matching user-supplied options;
    # keys that belong to no group are silently ignored.
    for default_kwargs in (reader_kwargs, calc_kwargs, compare_kwargs):
        for key in kwargs.keys() & default_kwargs.keys():
            default_kwargs[key] = kwargs[key]

    crystals = _unwrap_periodicset_list(crystals, **reader_kwargs)
    if not crystals:
        raise ValueError('No valid crystals to compare in first set.')
    names = [s.name for s in crystals]

    if crystals_ is None:
        names_ = names
    else:
        crystals_ = _unwrap_periodicset_list(crystals_, **reader_kwargs)
        if not crystals_:
            raise ValueError('No valid crystals to compare in second set.')
        names_ = [s.name for s in crystals_]

    # Names were collected above, before the crystals are replaced by
    # (molecular_centres, cell) tuples.
    if reader_kwargs['molecular_centres']:
        crystals = [(c.molecular_centres, c.cell) for c in crystals]
        if crystals_ is not None:
            crystals_ = [(c.molecular_centres, c.cell) for c in crystals_]

    if by == 'AMD':
        invs = [AMD(s, k) for s in crystals]
        # AMD_pdist / AMD_cdist forward extra kwargs to scipy, which does not
        # understand the PDD-only parallelisation options.
        compare_kwargs.pop('n_jobs', None)
        compare_kwargs.pop('verbose', None)

        if crystals_ is None:
            dm = squareform(AMD_pdist(invs, **compare_kwargs))
        else:
            invs_ = [AMD(s, k) for s in crystals_]
            dm = AMD_cdist(invs, invs_, **compare_kwargs)

    else:  # by == 'PDD' (validated above)
        # Both sets must be computed with the same options to be comparable.
        invs = [PDD(s, k, **calc_kwargs) for s in crystals]
        compare_kwargs.pop('low_memory', None)

        if crystals_ is None:
            dm = squareform(PDD_pdist(invs, **compare_kwargs))
        else:
            invs_ = [PDD(s, k, **calc_kwargs) for s in crystals_]
            dm = PDD_cdist(invs, invs_, **compare_kwargs)

    return pd.DataFrame(dm, index=names, columns=names_)
151
152
153
def EMD(
        pdd: np.ndarray,
        pdd_: np.ndarray,
        metric: Optional[str] = 'chebyshev',
        return_transport: Optional[bool] = False,
        **kwargs):
    r"""Earth mover's distance (EMD) between two PDDs, also known as
    the Wasserstein metric.

    Parameters
    ----------
    pdd : numpy.ndarray
        PDD of a crystal.
    pdd\_ : numpy.ndarray
        PDD of a crystal.
    metric : str or callable, default 'chebyshev'
        EMD between PDDs requires defining a distance between PDD rows.
        By default, Chebyshev (L-infinity) distance is chosen as with AMDs.
        Accepts any metric accepted by :func:`scipy.spatial.distance.cdist`.
    return_transport: bool, default False
        Return a tuple ``(distance, transport_plan)`` with the optimal transport.
    **kwargs
        Forwarded to :func:`scipy.spatial.distance.cdist`.

    Returns
    -------
    emd : float
        Earth mover's distance between two PDDs.

    Raises
    ------
    ValueError
        Thrown if ``pdd`` and ``pdd_`` do not have the same number of
        columns (``k`` value).
    """

    # Column 0 of each PDD is handed to network_simplex as the row weights;
    # the remaining columns are compared row-against-row to build the costs.
    weights, weights_ = pdd[:, 0], pdd_[:, 0]
    cost_matrix = cdist(pdd[:, 1:], pdd_[:, 1:], metric=metric, **kwargs)
    distance, plan = network_simplex(weights, weights_, cost_matrix)
    return (distance, plan) if return_transport else distance
194
195
196
def AMD_cdist(
        amds: Union[np.ndarray, List[np.ndarray]],
        amds_: Union[np.ndarray, List[np.ndarray]],
        metric: str = 'chebyshev',
        low_memory: bool = False,
        **kwargs
) -> np.ndarray:
    r"""Compare two sets of AMDs with each other, returning a distance matrix.
    This function is essentially identical to :func:`scipy.spatial.distance.cdist`
    with the default metric ``chebyshev``.

    Parameters
    ----------
    amds : array_like
        A list of AMDs. A single 1D AMD is treated as a list of one.
    amds\_ : array_like
        A list of AMDs. A single 1D AMD is treated as a list of one.
    metric : str or callable, default 'chebyshev'
        Usually AMDs are compared with the Chebyshev (L-infinity) distance.
        Can take any metric accepted by :func:`scipy.spatial.distance.cdist`.
    low_memory : bool, default False
        Use a slower but more memory efficient method for
        large collections of AMDs (Chebyshev metric only). When True, any
        other ``metric`` triggers a warning and is ignored, as are ``kwargs``.
    **kwargs
        Passed to :func:`scipy.spatial.distance.cdist` when ``low_memory`` is False.

    Returns
    -------
    dm : numpy.ndarray
        A distance matrix shape ``(len(amds), len(amds_))``.
        ``dm[i, j]`` is the distance (given by ``metric``)
        between ``amds[i]`` and ``amds_[j]``.
    """

    amds, amds_ = np.asarray(amds), np.asarray(amds_)

    # Promote single AMD vectors to a collection of one.
    if amds.ndim == 1:
        amds = np.array([amds])
    if amds_.ndim == 1:
        amds_ = np.array([amds_])

    if not low_memory:
        return cdist(amds, amds_, metric=metric, **kwargs)

    if metric != 'chebyshev':
        warnings.warn("Using only allowed metric 'chebyshev' for low_memory", UserWarning)

    # Fill one row at a time so the only temporary is the size of amds_,
    # rather than the full pairwise difference array.
    dm = np.empty((len(amds), len(amds_)))
    for i, amd_vec in enumerate(amds):
        dm[i] = np.amax(np.abs(amds_ - amd_vec), axis=-1)
    return dm
246
247
248
def AMD_pdist(
        amds: Union[np.ndarray, List[np.ndarray]],
        metric: str = 'chebyshev',
        low_memory: bool = False,
        **kwargs
) -> np.ndarray:
    """Compare a set of AMDs pairwise, returning a condensed distance matrix.
    This function is essentially identical to :func:`scipy.spatial.distance.pdist`
    with the default metric ``chebyshev``.

    Parameters
    ----------
    amds : array_like
        An array/list of AMDs. A single 1D AMD is treated as a list of one.
    metric : str or callable, default 'chebyshev'
        Usually AMDs are compared with the Chebyshev (L-infinity) distance.
        Can take any metric accepted by :func:`scipy.spatial.distance.pdist`.
    low_memory : bool, default False
        Optionally use a slightly slower but more memory efficient method for
        large collections of AMDs (Chebyshev metric only). In this mode any
        other ``metric`` triggers a warning and ``kwargs`` are not used.
    **kwargs
        Passed to :func:`scipy.spatial.distance.pdist` when ``low_memory`` is False.

    Returns
    -------
    numpy.ndarray
        Returns a condensed distance matrix. Collapses a square distance
        matrix into a vector, just keeping the upper half. See
        :func:`scipy.spatial.distance.squareform` to convert to a square
        distance matrix or for more on condensed distance matrices.
    """

    amds = np.asarray(amds)
    if amds.ndim == 1:
        amds = np.array([amds])

    if not low_memory:
        return pdist(amds, metric=metric, **kwargs)

    if metric != 'chebyshev':
        warnings.warn("Using only allowed metric 'chebyshev' for low_memory", UserWarning)

    n_amds = len(amds)
    condensed = np.empty((n_amds * (n_amds - 1)) // 2, dtype=np.double)
    # Row `row` of the upper triangle occupies the next n_amds - row - 1 slots.
    start = 0
    for row, amd_vec in enumerate(amds):
        stop = start + n_amds - row - 1
        condensed[start:stop] = np.amax(np.abs(amds[row + 1:] - amd_vec), axis=-1)
        start = stop
    return condensed
297
298
299
def PDD_cdist(
        pdds: List[np.ndarray],
        pdds_: List[np.ndarray],
        metric: str = 'chebyshev',
        backend='multiprocessing',
        n_jobs=None,
        verbose=0,
        **kwargs
) -> np.ndarray:
    r"""Compare two sets of PDDs with each other, returning a distance matrix.
    Supports parallel processing via joblib. If using parallelisation, make sure to
    include a if __name__ == '__main__' guard around this function.

    Parameters
    ----------
    pdds : List[numpy.ndarray]
        A list of PDDs. A single PDD given as a 2D array is treated as a list of one.
    pdds\_ : List[numpy.ndarray]
        A list of PDDs. A single PDD given as a 2D array is treated as a list of one.
    metric : str or callable, default 'chebyshev'
        Usually PDD rows are compared with the Chebyshev/l-infinity distance.
        Can take any metric accepted by :func:`scipy.spatial.distance.cdist`.
    backend : str, default 'multiprocessing'
        Specifies the parallelization backend implementation. For a list of
        supported backends, see the backend argument of :class:`joblib.Parallel`.
    n_jobs : int, default None
        Maximum number of concurrent jobs for parallel processing with joblib.
        Set to -1 to use the maximum possible; other negative values are
        interpreted by joblib. ``None``, 0 or 1 run serially. Note that for
        small inputs (< 100), using parallel processing may be slower than the
        default n_jobs=None.
    verbose : int, default 0
        Controls verbosity. If using parallel processing, verbose is
        passed to :class:`joblib.Parallel`, where larger values = more verbosity.
        Otherwise, uses progressbar2 where the progressbar is either on or off.
    **kwargs
        Passed to :func:`EMD`; ``return_transport`` is discarded.

    Returns
    -------
    numpy.ndarray
        Returns a distance matrix shape ``(len(pdds), len(pdds_))``.
        The :math:`ij` th entry is the distance between ``pdds[i]``
        and ``pdds_[j]`` given by Earth mover's distance.
    """

    # A bare 2D array is one PDD, not a collection of PDDs.
    if isinstance(pdds, np.ndarray) and pdds.ndim == 2:
        pdds = [pdds]
    if isinstance(pdds_, np.ndarray) and pdds_.ndim == 2:
        pdds_ = [pdds_]

    # Only the scalar distance is wanted for each matrix entry.
    kwargs.pop('return_transport', None)
    n, m = len(pdds), len(pdds_)

    # joblib treats negative n_jobs as relative to the CPU count (-1 = all
    # CPUs), so parallelise for those as well as for n_jobs > 1.
    if n_jobs is not None and (n_jobs > 1 or n_jobs < 0):
        emd_func = partial(EMD, metric=metric, **kwargs)
        results = Parallel(backend=backend, n_jobs=n_jobs, verbose=verbose)(
            delayed(emd_func)(pdds[i], pdds_[j])
            for i in range(n) for j in range(m)
        )
        return np.array(results).reshape((n, m))

    dm = np.empty((n, m))
    progress = ProgressBar(max_value=n * m) if verbose else None
    count = 0
    for i in range(n):
        for j in range(m):
            dm[i, j] = EMD(pdds[i], pdds_[j], metric=metric, **kwargs)
            if progress is not None:
                count += 1
                progress.update(count)
    return dm
372
373
374
def PDD_pdist(
        pdds: List[np.ndarray],
        metric: str = 'chebyshev',
        n_jobs=None,
        verbose=0,
        backend='multiprocessing',
        **kwargs
) -> np.ndarray:
    """Compare a set of PDDs pairwise, returning a condensed distance matrix.
    Supports parallelisation via joblib. If using parallelisation, make sure to
    include a if __name__ == '__main__' guard around this function.

    Parameters
    ----------
    pdds : List[numpy.ndarray]
        A list of PDDs. A single PDD given as a 2D array is treated as a list of one.
    metric : str or callable, default 'chebyshev'
        Usually PDD rows are compared with the Chebyshev/l-infinity distance.
        Can take any metric accepted by :func:`scipy.spatial.distance.pdist`.
    n_jobs : int, default None
        Maximum number of concurrent jobs for parallel processing with joblib.
        Set to -1 to use the maximum possible; other negative values are
        interpreted by joblib. ``None``, 0 or 1 run serially. Note that for
        small inputs (< 100), using parallel processing may be slower than the
        default n_jobs=None.
    verbose : int, default 0
        Controls verbosity. If using parallel processing, verbose is
        passed to :class:`joblib.Parallel`, where larger values = more verbosity.
        Otherwise, uses progressbar2 where the progress bar is either on or off.
    backend : str, default 'multiprocessing'
        Specifies the parallelization backend implementation. For a list of
        supported backends, see the backend argument of :class:`joblib.Parallel`.
    **kwargs
        Passed to :func:`EMD`; ``return_transport`` is discarded.

    Returns
    -------
    numpy.ndarray
        Returns a condensed distance matrix of length
        ``len(pdds) * (len(pdds) - 1) // 2``. Collapses a square
        distance matrix into a vector just keeping the upper half. See
        :func:`scipy.spatial.distance.squareform` to convert to a square
        distance matrix or for more on condensed distance matrices.
    """

    # A bare 2D array is one PDD, not a collection of PDDs.
    if isinstance(pdds, np.ndarray) and pdds.ndim == 2:
        pdds = [pdds]

    # Only the scalar distance is wanted for each entry.
    kwargs.pop('return_transport', None)
    m = len(pdds)

    # joblib treats negative n_jobs as relative to the CPU count (-1 = all
    # CPUs), so parallelise for those as well as for n_jobs > 1.
    if n_jobs is not None and (n_jobs > 1 or n_jobs < 0):
        emd_func = partial(EMD, metric=metric, **kwargs)
        results = Parallel(backend=backend, n_jobs=n_jobs, verbose=verbose)(
            delayed(emd_func)(pdds[i], pdds[j])
            for i, j in combinations(range(m), 2)
        )
        return np.array(results, dtype=np.double)

    # combinations() yields (i, j), i < j, in the same order as scipy's
    # condensed distance matrices.
    cdm = np.empty((m * (m - 1)) // 2, dtype=np.double)
    progress = ProgressBar(max_value=len(cdm)) if verbose else None
    for r, (i, j) in enumerate(combinations(range(m), 2)):
        cdm[r] = EMD(pdds[i], pdds[j], metric=metric, **kwargs)
        if progress is not None:
            progress.update(r + 1)
    return cdm
436
437
438
def emd(
        pdd: np.ndarray,
        pdd_: np.ndarray,
        metric: Optional[str] = 'chebyshev',
        return_transport: Optional[bool] = False,
        **kwargs):
    """Alias for amd.EMD(); takes the same arguments and returns the same result."""
    emd_options = {'metric': metric, 'return_transport': return_transport, **kwargs}
    return EMD(pdd, pdd_, **emd_options)
446
447
448
def _unwrap_periodicset_list(psets_or_str, **reader_kwargs):
    """Normalise valid input for compare into a flat list of PeriodicSets.

    Accepts a PeriodicSet, a path, refcode or file object, or a list of such.
    Each non-PeriodicSet item is expanded with ``_extract_periodicsets``
    (receiving ``reader_kwargs``), and list results are flattened in order.
    """

    if isinstance(psets_or_str, PeriodicSet):
        return [psets_or_str]

    if not isinstance(psets_or_str, list):
        return _extract_periodicsets(psets_or_str, **reader_kwargs)

    unwrapped = []
    for item in psets_or_str:
        unwrapped.extend(_extract_periodicsets(item, **reader_kwargs))
    return unwrapped
459
460
461
def _extract_periodicsets(item, **reader_kwargs):
    """Turn a str (path or refcode), file or PeriodicSet into a list of PeriodicSets.

    A string that is neither an existing file nor an existing directory is
    treated as a CSD refcode and read with CSDReader, without the 'reader'
    option. Anything else is read with CifReader, without the 'families' and
    'refcodes' options.
    """

    if isinstance(item, PeriodicSet):
        return [item]

    is_refcode = isinstance(item, str) and not (os.path.isfile(item) or os.path.isdir(item))
    if is_refcode:
        reader_kwargs.pop('reader', None)
        return list(CSDReader(item, **reader_kwargs))

    for unused_by_cifreader in ('families', 'refcodes'):
        reader_kwargs.pop(unused_by_cifreader, None)
    return list(CifReader(item, **reader_kwargs))
473