Passed
Push — master ( bc96b3...eb28f3 )
by Daniel
06:34
created

amd.compare._unwrap_periodicset_list()   B

Complexity

Conditions 7

Size

Total Lines 26
Code Lines 17

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 17
dl 0
loc 26
rs 8
c 0
b 0
f 0
cc 7
nop 2
1
"""Functions for comparing AMDs and PDDs of crystals.
2
"""
3
4
import warnings
5
from typing import List, Optional, Union
6
from functools import partial
7
from itertools import combinations
8
import os
9
10
import numpy as np
11
import pandas as pd
12
from scipy.spatial.distance import cdist, pdist, squareform
13
from joblib import Parallel, delayed
14
import tqdm
15
16
from .io import CifReader, CSDReader
17
from .calculate import AMD, PDD
18
from ._emd import network_simplex
19
from .periodicset import PeriodicSet
20
from .utils import neighbours_from_distance_matrix
21
22
23
def compare(
        crystals,
        crystals_=None,
        by='AMD',
        k=100,
        nearest=None,
        reader='ase',
        remove_hydrogens=False,
        disorder='skip',
        heaviest_component=False,
        molecular_centres=False,
        families=False,
        show_warnings=True,
        collapse_tol=1e-4,
        metric='chebyshev',
        n_jobs=None,
        verbose=0,
        low_memory=False,
) -> pd.DataFrame:
    r"""Given one or two sets of periodic set(s), paths to cif(s) or
    refcode(s), compare them and return a DataFrame of the distance
    matrix. Default is to compare by AMD with k = 100. Accepts most
    keyword arguments accepted by :class:`CifReader <.io.CifReader>`,
    :class:`CSDReader <.io.CSDReader>` and functions from
    :mod:`.compare`.

    Parameters
    ----------
    crystals : list of :class:`PeriodicSet <.periodicset.PeriodicSet>` or str
        One or a collection of paths, refcodes, file objects or
        :class:`PeriodicSets <.periodicset.PeriodicSet>`.
    crystals\_ : list of :class:`PeriodicSet <.periodicset.PeriodicSet>` or str, optional
        One or a collection of paths, refcodes, file objects or
        :class:`PeriodicSets <.periodicset.PeriodicSet>`. If omitted,
        ``crystals`` is compared pairwise with itself.
    by : str, default 'AMD'
        Use AMD or PDD to compare crystals (case-insensitive).
    k : int, default 100
        Number of neighbour atoms to use for AMD/PDD.
    nearest : int, default None
        Find a number of nearest neighbours instead of a full distance
        matrix between crystals.
    reader : str, optional
        The backend package used to parse the CIF. The default is
        :code:`ase`, :code:`pymatgen` and :code:`gemmi` are also
        accepted, as well as :code:`ccdc` if csd-python-api is
        installed. The ccdc reader should be able to read any format
        accepted by :class:`ccdc.io.EntryReader`, though only CIFs have
        been tested.
    remove_hydrogens : bool, optional
        Remove hydrogens from the crystals.
    disorder : str, optional
        Controls how disordered structures are handled. Default is
        ``skip`` which skips any crystal with disorder, since disorder
        conflicts with the periodic set model. To read disordered
        structures anyway, choose either :code:`ordered_sites` to remove
        atoms with disorder or :code:`all_sites` include all atoms
        regardless of disorder.
    heaviest_component : bool, optional, csd-python-api only
        Removes all but the heaviest molecule in
        the asymmetric unit, intended for removing solvents.
    molecular_centres : bool, default False, csd-python-api only
        Use the centres of molecules for comparison
        instead of centres of atoms.
    families : bool, optional, csd-python-api only
        Read all entries whose refcode starts with
        the given strings, or 'families' (e.g. giving 'DEBXIT' reads all
        entries with refcodes starting with DEBXIT).
    show_warnings : bool, optional
        Controls whether warnings that arise during reading are printed.
    collapse_tol: float, default 1e-4, ``by='PDD'`` only
        If two PDD rows have all elements closer
        than ``collapse_tol``, they are merged and weights are given to
        rows in proportion to the number of times they appeared.
    metric : str or callable, default 'chebyshev'
        The metric to compare AMDs/PDDs with. AMDs are compared directly
        with this metric. EMD is the metric used between PDDs, which
        requires giving a metric to use between PDD rows. Chebyshev
        (L-infinity) distance is the default. Accepts any metric
        accepted by :func:`scipy.spatial.distance.cdist`.
    n_jobs : int, default None, ``by='PDD'`` only
        Maximum number of concurrent jobs for
        parallel processing with :code:`joblib`. Set to -1 to use the
        maximum. Using parallel processing may be slower for small
        inputs.
    verbose : int, default 0, ``by='PDD'`` only
        Verbosity level. If using parallel processing (n_jobs > 1),
        passed to :class:`joblib.Parallel` where larger values = more
        verbose. Otherwise uses tqdm if verbose is True.
    low_memory : bool, default False, ``by='AMD'`` only
        Use a slower but more memory efficient
        method for large collections of AMDs (metric 'chebyshev' only).

    Returns
    -------
    df : :class:`pandas.DataFrame`
        DataFrame of the distance matrix for the given crystals compared
        by the chosen invariant. If ``nearest`` is given, the frame
        instead holds ``ID i``/``DIST i`` columns for each of the
        ``nearest`` neighbours.

    Raises
    ------
    ValueError
        If by is not 'AMD' or 'PDD', if either set given have no valid
        crystals to compare, or if crystals or crystals\_ are an invalid
        type.

    Examples
    --------
    Compare everything in a .cif (default, AMD with k=100)::

        df = amd.compare('data.cif')

    Compare everything in one cif with all crystals in all cifs in a
    directory (PDD, k=50)::

        df = amd.compare('data.cif', 'dir/to/cifs', by='PDD', k=50)

    **Examples (csd-python-api only)**

    Compare two crystals by CSD refcode (PDD, k=50)::

        df = amd.compare('DEBXIT01', 'DEBXIT02', by='PDD', k=50)

    Compare everything in a refcode family (AMD, k=100)::

        df = amd.compare('DEBXIT', families=True)
    """

    # Match case-insensitively, but keep ``by`` untouched so the error
    # message reports exactly what the caller passed.
    invariant = by.upper()
    if invariant not in ('AMD', 'PDD'):
        msg = f"parameter 'by' accepts 'AMD' or 'PDD', passed {by}"
        raise ValueError(msg)

    reader_kwargs = {
        'reader': reader,
        'families': families,
        'remove_hydrogens': remove_hydrogens,
        'disorder': disorder,
        'heaviest_component': heaviest_component,
        'molecular_centres': molecular_centres,
        'show_warnings': show_warnings,
    }

    crystals = _unwrap_periodicset_list(crystals, **reader_kwargs)
    if not crystals:
        raise ValueError('No valid crystals to compare in first set.')
    names = [s.name for s in crystals]

    if crystals_ is None:
        # One set only: rows and columns share the same labels.
        names_ = names
    else:
        crystals_ = _unwrap_periodicset_list(crystals_, **reader_kwargs)
        if not crystals_:
            raise ValueError('No valid crystals to compare in second set.')
        names_ = [s.name for s in crystals_]

    if molecular_centres:
        # Swap each crystal for a (molecular centres, cell) pair; names
        # were already collected above.
        crystals = [(c.molecular_centres, c.cell) for c in crystals]
        if crystals_ is not None:
            crystals_ = [(c.molecular_centres, c.cell) for c in crystals_]

    if invariant == 'AMD':
        # n_jobs and verbose only apply to PDD comparisons.
        amd_kwargs = {'metric': metric, 'low_memory': low_memory}
        invs = [AMD(s, k) for s in crystals]
        if crystals_ is None:
            dm = AMD_pdist(invs, **amd_kwargs)
        else:
            invs_ = [AMD(s, k) for s in crystals_]
            dm = AMD_cdist(invs, invs_, **amd_kwargs)

    else:
        # low_memory only applies to AMD comparisons.
        pdd_kwargs = {
            'collapse': True,
            'collapse_tol': collapse_tol,
            'lexsort': False,
        }
        emd_kwargs = {'metric': metric, 'n_jobs': n_jobs, 'verbose': verbose}
        invs = [PDD(s, k, **pdd_kwargs) for s in crystals]
        if crystals_ is None:
            dm = PDD_pdist(invs, **emd_kwargs)
        else:
            invs_ = [PDD(s, k, **pdd_kwargs) for s in crystals_]
            dm = PDD_cdist(invs, invs_, **emd_kwargs)

    if nearest:
        nn_dm, inds = neighbours_from_distance_matrix(nearest, dm)
        data = {}
        for i in range(nearest):
            data[f'ID {i + 1}'] = [names_[j] for j in inds[:, i]]
            data[f'DIST {i + 1}'] = nn_dm[:, i]
        return pd.DataFrame(data, index=names)

    # The pdist functions return a condensed (1D) matrix; expand it to a
    # square matrix before labelling.
    if dm.ndim == 1:
        dm = squareform(dm)
    return pd.DataFrame(dm, index=names, columns=names_)
233
234
235
def EMD(
        pdd: np.ndarray,
        pdd_: np.ndarray,
        metric: Optional[str] = 'chebyshev',
        return_transport: Optional[bool] = False,
        **kwargs):
    r"""Earth mover's distance (EMD), also known as the Wasserstein
    metric, between two PDDs.

    Column 0 of each PDD is handed to the network simplex solver
    together with the matrix of distances between the remaining columns
    of every pair of rows.

    Parameters
    ----------
    pdd : :class:`numpy.ndarray`
        PDD of a crystal.
    pdd\_ : :class:`numpy.ndarray`
        PDD of a crystal.
    metric : str or callable, default 'chebyshev'
        Distance used between PDD rows. Chebyshev (L-infinity) is the
        default, as with AMDs. Accepts any metric accepted by
        :func:`scipy.spatial.distance.cdist`.
    return_transport: bool, default False
        If True, return a tuple ``(emd, transport_plan)`` instead of
        just the distance.
    **kwargs
        Forwarded to :func:`scipy.spatial.distance.cdist`.

    Returns
    -------
    emd : float
        Earth mover's distance between the two PDDs, or the tuple
        ``(emd, transport_plan)`` if ``return_transport`` is True.

    Raises
    ------
    ValueError
        Thrown if ``pdd`` and ``pdd_`` do not have the same number of
        columns.
    """

    # Pairwise distances between rows, ignoring the first column.
    row_distances = cdist(pdd[:, 1:], pdd_[:, 1:], metric=metric, **kwargs)
    emd_dist, transport_plan = network_simplex(
        pdd[:, 0], pdd_[:, 0], row_distances
    )

    if not return_transport:
        return emd_dist
    return emd_dist, transport_plan
279
280
281
def AMD_cdist(
        amds: Union[np.ndarray, List[np.ndarray]],
        amds_: Union[np.ndarray, List[np.ndarray]],
        metric: str = 'chebyshev',
        low_memory: bool = False,
        **kwargs
) -> np.ndarray:
    r"""Distance matrix between two collections of AMDs.

    Behaves like :func:`scipy.spatial.distance.cdist`, except that the
    default metric is ``chebyshev`` and a low memory mode is available.

    Parameters
    ----------
    amds : array_like
        A list/array of AMDs. A single AMD (1D) is treated as a
        collection of one.
    amds\_ : array_like
        A list/array of AMDs. A single AMD (1D) is treated as a
        collection of one.
    metric : str or callable, default 'chebyshev'
        Usually AMDs are compared with the Chebyshev (L-infinity)
        distance. Accepts any metric accepted by
        :func:`scipy.spatial.distance.cdist`.
    low_memory : bool, default False
        Use a slower but more memory efficient method for large
        collections of AMDs. Only Chebyshev distance is computed in this
        mode; a warning is issued if another metric was requested.
    **kwargs
        Forwarded to :func:`scipy.spatial.distance.cdist` when
        ``low_memory`` is False.

    Returns
    -------
    dm : :class:`numpy.ndarray`
        A distance matrix of shape ``(len(amds), len(amds_))`` where
        ``dm[i, j]`` is the distance between ``amds[i]`` and
        ``amds_[j]``.
    """

    amds = np.asarray(amds)
    amds_ = np.asarray(amds_)

    # Promote a single AMD vector to a one-row collection.
    if amds.ndim == 1:
        amds = amds[np.newaxis, :]
    if amds_.ndim == 1:
        amds_ = amds_[np.newaxis, :]

    if not low_memory:
        return cdist(amds, amds_, metric=metric, **kwargs)

    if metric != 'chebyshev':
        msg = "Using only allowed metric 'chebyshev' for low_memory"
        warnings.warn(msg, UserWarning)

    # Fill one row at a time so only a single (len(amds_), k) temporary
    # is alive at once.
    dm = np.empty((len(amds), len(amds_)))
    for row, vec in enumerate(amds):
        dm[row] = np.abs(amds_ - vec).max(axis=-1)
    return dm
332
333
334
def AMD_pdist(
        amds: Union[np.ndarray, List[np.ndarray]],
        metric: str = 'chebyshev',
        low_memory: bool = False,
        **kwargs
) -> np.ndarray:
    """Condensed pairwise distance matrix of a collection of AMDs.

    Behaves like :func:`scipy.spatial.distance.pdist`, except that the
    default metric is ``chebyshev`` and a low memory mode is available.

    Parameters
    ----------
    amds : array_like
        A list/array of AMDs. A single AMD (1D) is treated as a
        collection of one.
    metric : str or callable, default 'chebyshev'
        Usually AMDs are compared with the Chebyshev (L-infinity)
        distance. Accepts any metric accepted by
        :func:`scipy.spatial.distance.pdist`.
    low_memory : bool, default False
        Use a slower but more memory efficient method for large
        collections of AMDs. Only Chebyshev distance is computed in this
        mode; a warning is issued if another metric was requested.
    **kwargs
        Forwarded to :func:`scipy.spatial.distance.pdist` when
        ``low_memory`` is False.

    Returns
    -------
    cdm : :class:`numpy.ndarray`
        A condensed distance matrix: the upper half of the square
        distance matrix flattened into a vector. Use
        :func:`squareform <scipy.spatial.distance.squareform>` from
        SciPy to convert to a symmetric square distance matrix.
    """

    amds = np.asarray(amds)

    # Promote a single AMD vector to a one-row collection.
    if amds.ndim == 1:
        amds = amds[np.newaxis, :]

    if not low_memory:
        return pdist(amds, metric=metric, **kwargs)

    count = len(amds)
    if metric != 'chebyshev':
        msg = "Using only allowed metric 'chebyshev' for low_memory"
        warnings.warn(msg, UserWarning)

    cdm = np.empty((count * (count - 1)) // 2, dtype=np.double)
    start = 0
    for i in range(count):
        # Row i of the condensed matrix holds distances to items i+1..end.
        stop = start + (count - i - 1)
        cdm[start:stop] = np.abs(amds[i + 1:] - amds[i]).max(axis=-1)
        start = stop
    return cdm
386
387
388
def PDD_cdist(
        pdds: List[np.ndarray],
        pdds_: List[np.ndarray],
        metric: str = 'chebyshev',
        backend='multiprocessing',
        n_jobs=None,
        verbose=0,
        **kwargs
) -> np.ndarray:
    r"""Distance matrix between two collections of PDDs, using Earth
    mover's distance. Supports parallel processing via joblib. If using
    parallelisation, make sure to include a if __name__ == '__main__'
    guard around this function.

    Parameters
    ----------
    pdds : List[:class:`numpy.ndarray`]
        A list of PDDs. A single 2D array is treated as a list of one.
    pdds\_ : List[:class:`numpy.ndarray`]
        A list of PDDs. A single 2D array is treated as a list of one.
    metric : str or callable, default 'chebyshev'
        Usually PDD rows are compared with the Chebyshev/l-infinity
        distance. Accepts any metric accepted by
        :func:`scipy.spatial.distance.cdist`.
    backend : str, default 'multiprocessing'
        The parallelization backend implementation. For a list of
        supported backends, see the backend argument of
        :class:`joblib.Parallel`.
    n_jobs : int, default None
        Maximum number of concurrent jobs for parallel processing with
        ``joblib``. Set to -1 to use the maximum. ``None``, 0 and 1 run
        serially. Using parallel processing may be slower for small
        inputs.
    verbose : int, default 0
        Verbosity level. If using parallel processing, passed to
        :class:`joblib.Parallel` where larger values = more verbose.
        Otherwise a tqdm progress bar is shown if verbose is truthy.
    **kwargs
        Forwarded to :func:`EMD` (``return_transport`` is discarded).

    Returns
    -------
    dm : :class:`numpy.ndarray`
        A distance matrix of shape ``(len(pdds), len(pdds_))``. The
        :math:`ij` th entry is the Earth mover's distance between
        ``pdds[i]`` and ``pdds_[j]``.
    """

    # A bare 2D array is one PDD, not a collection of them.
    if isinstance(pdds, np.ndarray) and pdds.ndim == 2:
        pdds = [pdds]
    if isinstance(pdds_, np.ndarray) and pdds_.ndim == 2:
        pdds_ = [pdds_]

    # Every entry must be a plain distance, never a (distance, plan) tuple.
    kwargs.pop('return_transport', None)

    run_serial = n_jobs is None or n_jobs in (0, 1)

    if not run_serial:
        # TODO: put results into preallocated empty array in place
        pool = Parallel(backend=backend, n_jobs=n_jobs, verbose=verbose)
        flat = pool(
            delayed(partial(EMD, metric=metric, **kwargs))(pdds[i], pdds_[j])
            for i in range(len(pdds)) for j in range(len(pdds_))
        )
        return np.array(flat).reshape((len(pdds), len(pdds_)))

    n_rows, n_cols = len(pdds), len(pdds_)
    dm = np.empty((n_rows, n_cols))
    progress = tqdm.tqdm(total=n_rows * n_cols) if verbose else None
    # np.ndindex walks (i, j) in row-major order.
    for i, j in np.ndindex(n_rows, n_cols):
        dm[i, j] = EMD(pdds[i], pdds_[j], metric=metric, **kwargs)
        if verbose:
            progress.update(1)
    if verbose:
        progress.close()

    return dm
465
466
467
def PDD_pdist(
        pdds: List[np.ndarray],
        metric: str = 'chebyshev',
        backend='multiprocessing',
        n_jobs=None,
        verbose=0,
        **kwargs
) -> np.ndarray:
    """Compare a set of PDDs pairwise, returning a condensed distance
    matrix. Supports parallelisation via joblib. If using
    parallelisation, make sure to include a if __name__ == '__main__'
    guard around this function.

    Parameters
    ----------
    pdds : List[:class:`numpy.ndarray`]
        A list of PDDs.
    metric : str or callable, default 'chebyshev'
        Usually PDD rows are compared with the Chebyshev/l-infinity
        distance. Accepts any metric accepted by
        :func:`scipy.spatial.distance.cdist`.
    backend : str, default 'multiprocessing'
        The parallelization backend implementation. For a list of
        supported backends, see the backend argument of
        :class:`joblib.Parallel`.
    n_jobs : int, default None
        Maximum number of concurrent jobs for parallel processing with
        ``joblib``. Set to -1 to use the maximum. ``None``, 0 and 1 run
        serially; any other value (including negative values) is passed
        to :class:`joblib.Parallel`. Using parallel processing may be
        slower for small inputs.
    verbose : int, default 0
        Verbosity level. If using parallel processing, passed to
        :class:`joblib.Parallel` where larger values = more verbose.
        Otherwise a tqdm progress bar is shown if verbose is truthy.
    **kwargs
        Forwarded to :func:`EMD` (``return_transport`` is discarded).

    Returns
    -------
    cdm : :class:`numpy.ndarray`
        Returns a condensed distance matrix. Collapses a square distance
        matrix into a vector, just keeping the upper half. See the
        function :func:`squareform <scipy.spatial.distance.squareform>`
        from SciPy to convert to a symmetric square distance matrix.
    """

    # Every entry must be a plain distance, never a (distance, plan) tuple.
    kwargs.pop('return_transport', None)

    # Same gate as PDD_cdist: the documented n_jobs=-1 ("use the maximum")
    # must take the parallel path, which a plain ``n_jobs > 1`` test misses.
    if n_jobs is not None and n_jobs not in (0, 1):
        # TODO: put results into preallocated empty array in place
        emd_func = partial(EMD, metric=metric, **kwargs)
        cdm = Parallel(backend=backend, n_jobs=n_jobs, verbose=verbose)(
            delayed(emd_func)(pdds[i], pdds[j])
            for i, j in combinations(range(len(pdds)), 2)
        )
        return np.array(cdm)

    m = len(pdds)
    cdm_len = (m * (m - 1)) // 2
    cdm = np.empty(cdm_len, dtype=np.double)

    # The count goes in as ``total``; tqdm's first positional argument is
    # the iterable to wrap, so passing the count there gives no total.
    pbar = tqdm.tqdm(total=cdm_len) if verbose else None
    try:
        # combinations() yields (i, j) with i < j in condensed-matrix order.
        for r, (i, j) in enumerate(combinations(range(m), 2)):
            cdm[r] = EMD(pdds[i], pdds[j], metric=metric, **kwargs)
            if pbar is not None:
                pbar.update(1)
    finally:
        # Close the bar even if a comparison raises.
        if pbar is not None:
            pbar.close()

    return cdm
534
535
536
def emd(pdd: np.ndarray, pdd_: np.ndarray, **kwargs):
    """Lower-case alias for :func:`EMD() <.compare.EMD>`; all arguments
    are forwarded unchanged.
    """
    distance = EMD(pdd, pdd_, **kwargs)
    return distance
539
540
541
def _unwrap_periodicset_list(psets_or_str, **reader_kwargs):
    """Normalise valid input for amd.compare() into a flat list of
    ``PeriodicSet`` objects.

    ``psets_or_str`` may be a ``PeriodicSet``, a str (path or refcode),
    a file object, or a list of any of these. ``reader_kwargs`` are
    forwarded to whichever reader is used for each item.
    """

    def _read_item(item, **kwargs):
        """Turn one str (path/refcode), file or ``PeriodicSet`` into a
        list of ``PeriodicSets``.
        """

        if isinstance(item, PeriodicSet):
            return [item]

        # A string naming neither a file nor a directory is handed to
        # CSDReader, which takes no 'reader' argument.
        if isinstance(item, str):
            if not (os.path.isfile(item) or os.path.isdir(item)):
                csd_kwargs = {
                    key: val for key, val in kwargs.items()
                    if key != 'reader'
                }
                return list(CSDReader(item, **csd_kwargs))

        # Anything else goes to CifReader, without the CSD-only options.
        cif_kwargs = {
            key: val for key, val in kwargs.items()
            if key not in ('families', 'refcodes')
        }
        return list(CifReader(item, **cif_kwargs))

    if isinstance(psets_or_str, PeriodicSet):
        return [psets_or_str]

    if not isinstance(psets_or_str, list):
        return _read_item(psets_or_str, **reader_kwargs)

    unwrapped = []
    for entry in psets_or_str:
        unwrapped.extend(_read_item(entry, **reader_kwargs))
    return unwrapped
567