Test Failed
Push — master ( add8ad...560841 )
by Daniel
09:07
created

amd.compare.PDD_cdist()   C

Complexity

Conditions 9

Size

Total Lines 73
Code Lines 30

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 30
dl 0
loc 73
rs 6.6666
c 0
b 0
f 0
cc 9
nop 7

How to fix   Long Method   

Long Method

Small methods make your code easier to understand, in particular if combined with a good name. Besides, if your method is small, finding a good name is usually much easier.

For example, if you find yourself adding comments to a method's body, this is usually a good sign to extract the commented part to a new method, and use the comment as a starting point when coming up with a good name for this new method.

Commonly applied refactorings include:

1
"""Functions for comparing AMDs and PDDs of crystals.
2
"""
3
4
import warnings
5
from typing import List, Optional, Union, Tuple
6
from functools import partial
7
from itertools import combinations
8
import os
9
10
import numpy as np
11
import numpy.typing as npt
12
import pandas as pd
13
from scipy.spatial.distance import cdist, pdist, squareform
14
from joblib import Parallel, delayed
15
import tqdm
16
17
from .io import CifReader, CSDReader
18
from .calculate import AMD, PDD
19
from ._emd import network_simplex
20
from .periodicset import PeriodicSet, PeriodicSetType
21
from .utils import neighbours_from_distance_matrix
22
23
24
def compare(
        crystals: Union[str, PeriodicSetType, List],
        crystals_: Optional[Union[str, PeriodicSetType, List]] = None,
        by: str = 'AMD',
        k: int = 100,
        nearest: Optional[int] = None,
        reader: str = 'ase',
        remove_hydrogens: bool = False,
        disorder: str = 'skip',
        heaviest_component: bool = False,
        molecular_centres: bool = False,
        families: bool = False,
        show_warnings: bool = True,
        collapse_tol: float = 1e-4,
        metric: str = 'chebyshev',
        n_jobs: Optional[int] = None,
        backend: str = 'multiprocessing',
        verbose: bool = False,
        low_memory: bool = False,
) -> pd.DataFrame:
    r"""Read one or two collections of crystals, compare them by AMD or
    PDD and return the distances as a DataFrame.

    With one collection the crystals are compared pairwise; with two,
    every crystal of the first is compared with every crystal of the
    second. The default is to compare by AMD with k = 100. Most keyword
    arguments of :class:`CifReader <.io.CifReader>`,
    :class:`CSDReader <.io.CSDReader>` and the functions in
    :mod:`.compare` are accepted.

    Parameters
    ----------
    crystals : list of :class:`PeriodicSet <.periodicset.PeriodicSet>` or str
        One or a collection of paths, refcodes, file objects or
        :class:`PeriodicSets <.periodicset.PeriodicSet>`.
    crystals\_ : list of :class:`PeriodicSet <.periodicset.PeriodicSet>` or str, optional
        One or a collection of paths, refcodes, file objects or
        :class:`PeriodicSets <.periodicset.PeriodicSet>`. If omitted,
        ``crystals`` is compared with itself.
    by : str, default 'AMD'
        Invariant used for the comparison, 'AMD' or 'PDD'
        (case-insensitive).
    k : int, default 100
        Number of neighbour atoms to use for AMD/PDD.
    nearest : int, default None
        If given, return this many nearest neighbours of each crystal
        instead of a full distance matrix.
    reader : str, optional
        The backend package used to parse the CIF. The default is
        :code:`ase`; :code:`pymatgen` and :code:`gemmi` are also
        accepted, as well as :code:`ccdc` if csd-python-api is
        installed. The ccdc reader should be able to read any format
        accepted by :class:`ccdc.io.EntryReader`, though only CIFs have
        been tested.
    remove_hydrogens : bool, optional
        Remove hydrogens from the crystals.
    disorder : str, optional
        Controls how disordered structures are handled. The default
        ``skip`` skips any crystal with disorder, since disorder
        conflicts with the periodic set model. To read disordered
        structures anyway, choose :code:`ordered_sites` to remove atoms
        with disorder or :code:`all_sites` to include all atoms
        regardless of disorder.
    heaviest_component : bool, optional, csd-python-api only
        Remove all but the heaviest molecule in the asymmetric unit,
        intended for removing solvents.
    molecular_centres : bool, default False, csd-python-api only
        Use the centres of molecules for comparison instead of the
        centres of atoms.
    families : bool, optional, csd-python-api only
        Read all entries whose refcode starts with the given strings,
        or 'families' (e.g. giving 'DEBXIT' reads all entries with
        refcodes starting with DEBXIT).
    show_warnings : bool, optional
        Controls whether warnings that arise during reading are printed.
    collapse_tol : float, default 1e-4, ``by='PDD'`` only
        If two PDD rows have all elements closer than ``collapse_tol``,
        they are merged and weights are given to rows in proportion to
        the number of times they appeared.
    metric : str or callable, default 'chebyshev'
        The metric to compare AMDs/PDDs with. AMDs are compared directly
        with this metric. PDDs are compared with Earth mover's distance,
        which uses this metric between PDD rows. Accepts any metric
        accepted by :func:`scipy.spatial.distance.cdist`.
    n_jobs : int, default None, ``by='PDD'`` only
        Maximum number of concurrent jobs for parallel processing with
        :code:`joblib`. Set to -1 to use the maximum. Using parallel
        processing may be slower for small inputs.
    backend : str, default 'multiprocessing', ``by='PDD'`` only
        The parallelization backend implementation for PDD comparisons.
        For a list of supported backends, see the backend argument of
        :class:`joblib.Parallel`.
    verbose : bool, default False
        Prints a progress bar when reading crystals and comparing PDDs.
        If using parallel processing (n_jobs > 1), the verbose argument
        of :class:`joblib.Parallel` is used, otherwise uses tqdm.
    low_memory : bool, default False, ``by='AMD'`` only
        Use a slower but more memory efficient method for large
        collections of AMDs (metric 'chebyshev' only).

    Returns
    -------
    df : :class:`pandas.DataFrame`
        DataFrame of the distance matrix for the given crystals compared
        by the chosen invariant. If ``nearest`` is given, the frame
        instead has columns ``ID i`` / ``DIST i`` holding the name of
        and distance to the i-th nearest neighbour of each crystal.

    Raises
    ------
    ValueError
        If by is not 'AMD' or 'PDD', if either set given has no valid
        crystals to compare, or if crystals or crystals\_ are an invalid
        type.

    Examples
    --------
    Compare everything in a .cif (default, AMD with k=100)::

        df = amd.compare('data.cif')

    Compare everything in one cif with all crystals in all cifs in a
    directory (PDD, k=50)::

        df = amd.compare('data.cif', 'dir/to/cifs', by='PDD', k=50)

    **Examples (csd-python-api only)**

    Compare two crystals by CSD refcode (PDD, k=50)::

        df = amd.compare('DEBXIT01', 'DEBXIT02', by='PDD', k=50)

    Compare everything in a refcode family (AMD, k=100)::

        df = amd.compare('DEBXIT', families=True)
    """

    by = by.upper()
    if by not in ('AMD', 'PDD'):
        raise ValueError(f"parameter 'by' accepts 'AMD' or 'PDD', passed {by}")

    # Options forwarded to whichever reader ends up parsing each input.
    reader_kwargs = dict(
        reader=reader,
        families=families,
        remove_hydrogens=remove_hydrogens,
        disorder=disorder,
        heaviest_component=heaviest_component,
        molecular_centres=molecular_centres,
        show_warnings=show_warnings,
        verbose=verbose,
    )

    first_set = _unwrap_periodicset_list(crystals, **reader_kwargs)
    if not first_set:
        raise ValueError('No valid crystals to compare in first set.')
    row_names = [pset.name for pset in first_set]

    # With no second collection the first one is compared with itself.
    pairwise = crystals_ is None
    if pairwise:
        second_set = None
        col_names = row_names
    else:
        second_set = _unwrap_periodicset_list(crystals_, **reader_kwargs)
        if not second_set:
            raise ValueError('No valid crystals to compare in second set.')
        col_names = [pset.name for pset in second_set]

    if by == 'AMD':
        # AMD comparisons take no parallelisation or progress options.
        amd_options = {'metric': metric, 'low_memory': low_memory}
        invariants = [AMD(pset, k) for pset in first_set]
        if pairwise:
            dm = AMD_pdist(invariants, **amd_options)
        else:
            invariants_ = [AMD(pset, k) for pset in second_set]
            dm = AMD_cdist(invariants, invariants_, **amd_options)
    else:
        # PDD comparisons have no low-memory mode.
        pdd_options = {
            'collapse': True,
            'collapse_tol': collapse_tol,
            'lexsort': False,
        }
        emd_options = {
            'metric': metric,
            'n_jobs': n_jobs,
            'backend': backend,
            'verbose': verbose,
        }
        invariants = [PDD(pset, k, **pdd_options) for pset in first_set]
        if pairwise:
            dm = PDD_pdist(invariants, **emd_options)
        else:
            invariants_ = [PDD(pset, k, **pdd_options) for pset in second_set]
            dm = PDD_cdist(invariants, invariants_, **emd_options)

    if nearest:
        nn_dists, nn_inds = neighbours_from_distance_matrix(nearest, dm)
        columns = {}
        for rank in range(nearest):
            columns[f'ID {rank + 1}'] = [col_names[j] for j in nn_inds[:, rank]]
            columns[f'DIST {rank + 1}'] = nn_dists[:, rank]
        return pd.DataFrame(columns, index=row_names)

    # The pairwise functions return a condensed (1-D) matrix; expand it.
    if dm.ndim == 1:
        dm = squareform(dm)
    return pd.DataFrame(dm, index=row_names, columns=col_names)
234
235
236
def EMD(
        pdd: npt.NDArray,
        pdd_: npt.NDArray,
        metric: Optional[str] = 'chebyshev',
        return_transport: Optional[bool] = False,
        **kwargs
) -> float:
    r"""Compute the Earth mover's distance (EMD, also known as the
    Wasserstein metric) between two PDDs.

    The first column of each PDD is used as the row weights and the
    remaining columns as the rows to be compared.

    Parameters
    ----------
    pdd : :class:`numpy.ndarray`
        PDD of a crystal.
    pdd\_ : :class:`numpy.ndarray`
        PDD of a crystal.
    metric : str or callable, default 'chebyshev'
        Distance used between PDD rows. Chebyshev (L-infinity) is the
        default, as with AMDs. Accepts any metric accepted by
        :func:`scipy.spatial.distance.cdist`.
    return_transport : bool, default False
        If True, return a tuple ``(emd, transport_plan)`` where
        transport_plan describes the optimal flow.
    **kwargs
        Extra keyword arguments forwarded to
        :func:`scipy.spatial.distance.cdist`.

    Returns
    -------
    emd : float
        Earth mover's distance between the two PDDs. If
        ``return_transport`` is True, a tuple (emd, transport_plan) is
        returned instead.

    Raises
    ------
    ValueError
        Thrown if ``pdd`` and ``pdd_`` do not have the same number of
        columns.
    """

    # Pairwise distances between the rows of the two PDDs (weights excluded).
    row_distances = cdist(pdd[:, 1:], pdd_[:, 1:], metric=metric, **kwargs)
    solution = network_simplex(pdd[:, 0], pdd_[:, 0], row_distances)
    emd_dist, transport_plan = solution

    return (emd_dist, transport_plan) if return_transport else emd_dist
280
281
282
def AMD_cdist(
        amds: npt.ArrayLike,
        amds_: npt.ArrayLike,
        metric: str = 'chebyshev',
        low_memory: bool = False,
        **kwargs
) -> npt.NDArray:
    r"""Compute the distance matrix between two collections of AMDs.

    Essentially :func:`scipy.spatial.distance.cdist` with ``chebyshev``
    as the default metric, plus an optional low memory code path.

    Parameters
    ----------
    amds : ArrayLike
        A list/array of AMDs. A single 1-D AMD is treated as a
        collection of one.
    amds\_ : ArrayLike
        A list/array of AMDs. A single 1-D AMD is treated as a
        collection of one.
    metric : str or callable, default 'chebyshev'
        Usually AMDs are compared with the Chebyshev (L-infinity)
        distance. Accepts any metric accepted by
        :func:`scipy.spatial.distance.cdist`.
    low_memory : bool, default False
        Use a slower but more memory efficient method for large
        collections of AMDs. Only Chebyshev distance is implemented on
        this path; any other ``metric`` triggers a warning and
        Chebyshev is used regardless.
    **kwargs
        Extra keyword arguments forwarded to
        :func:`scipy.spatial.distance.cdist` (unused when
        ``low_memory`` is True).

    Returns
    -------
    dm : :class:`numpy.ndarray`
        A distance matrix of shape ``(len(amds), len(amds_))``.
        ``dm[i, j]`` is the distance (given by ``metric``) between
        ``amds[i]`` and ``amds_[j]``.
    """

    first = np.asarray(amds)
    second = np.asarray(amds_)

    # Promote a lone AMD vector to a one-row matrix.
    if first.ndim == 1:
        first = first[np.newaxis, :]
    if second.ndim == 1:
        second = second[np.newaxis, :]

    if not low_memory:
        return cdist(first, second, metric=metric, **kwargs)

    if metric != 'chebyshev':
        warnings.warn(
            "Using only allowed metric 'chebyshev' for low_memory",
            UserWarning
        )

    # Fill one row at a time so only an (m, k) temporary is ever alive.
    dm = np.empty((len(first), len(second)))
    for out_row, amd_vec in zip(dm, first):
        out_row[...] = np.abs(second - amd_vec).max(axis=-1)
    return dm
333
334
335
def AMD_pdist(
        amds: npt.ArrayLike,
        metric: str = 'chebyshev',
        low_memory: bool = False,
        **kwargs
) -> npt.NDArray:
    """Compute the condensed pairwise distance matrix of a collection
    of AMDs.

    Essentially :func:`scipy.spatial.distance.pdist` with ``chebyshev``
    as the default metric, plus an optional low memory code path.

    Parameters
    ----------
    amds : ArrayLike
        A list/array of AMDs. A single 1-D AMD is treated as a
        collection of one.
    metric : str or callable, default 'chebyshev'
        Usually AMDs are compared with the Chebyshev (L-infinity)
        distance. Accepts any metric accepted by
        :func:`scipy.spatial.distance.pdist`.
    low_memory : bool, default False
        Use a slower but more memory efficient method for large
        collections of AMDs. Only Chebyshev distance is implemented on
        this path; any other ``metric`` triggers a warning and
        Chebyshev is used regardless.
    **kwargs
        Extra keyword arguments forwarded to
        :func:`scipy.spatial.distance.pdist` (unused when
        ``low_memory`` is True).

    Returns
    -------
    cdm : :class:`numpy.ndarray`
        A condensed distance matrix: the upper half of the square
        distance matrix flattened into a vector. Use
        :func:`squareform <scipy.spatial.distance.squareform>` from
        SciPy to convert it to a symmetric square distance matrix.
    """

    vectors = np.asarray(amds)

    # Promote a lone AMD vector to a one-row matrix.
    if vectors.ndim == 1:
        vectors = vectors[np.newaxis, :]

    if not low_memory:
        return pdist(vectors, metric=metric, **kwargs)

    count = len(vectors)
    if metric != 'chebyshev':
        warnings.warn(
            'Using only implemented metric "chebyshev" for low_memory',
            UserWarning
        )

    cdm = np.empty((count * (count - 1)) // 2, dtype=np.float64)
    start = 0
    for i in range(count):
        # Row i of the upper triangle holds count - i - 1 entries.
        stop = start + (count - i - 1)
        cdm[start:stop] = np.abs(vectors[i+1:] - vectors[i]).max(axis=-1)
        start = stop
    return cdm
387
388
389
def PDD_cdist(
        pdds: List[npt.NDArray],
        pdds_: List[npt.NDArray],
        metric: str = 'chebyshev',
        backend: str = 'multiprocessing',
        n_jobs: Optional[int] = None,
        verbose: bool = False,
        **kwargs
) -> npt.NDArray:
    r"""Compare two sets of PDDs with each other, returning a distance
    matrix. Supports parallel processing via joblib. If using
    parallelisation, make sure to include a if __name__ == '__main__'
    guard around this function.

    Parameters
    ----------
    pdds : List[:class:`numpy.ndarray`]
        A list of PDDs.
    pdds\_ : List[:class:`numpy.ndarray`]
        A list of PDDs.
    metric : str or callable, default 'chebyshev'
        Usually PDD rows are compared with the Chebyshev/l-infinity
        distance. Accepts any metric accepted by
        :func:`scipy.spatial.distance.cdist`.
    backend : str, default 'multiprocessing'
        The parallelization backend implementation. For a list of
        supported backends, see the backend argument of
        :class:`joblib.Parallel`.
    n_jobs : int, default None
        Maximum number of concurrent jobs for parallel processing with
        ``joblib``. Set to -1 to use the maximum. ``None``, 0 and 1 all
        run serially in this process. Using parallel processing may be
        slower for small inputs.
    verbose : bool, default False
        Prints a progress bar. If using parallel processing, the
        verbose argument of :class:`joblib.Parallel` is used, otherwise
        uses tqdm.
    **kwargs
        Extra keyword arguments forwarded to :func:`EMD`.
        ``return_transport`` is ignored.

    Returns
    -------
    dm : :class:`numpy.ndarray`
        Returns a distance matrix shape ``(len(pdds), len(pdds_))``. The
        :math:`ij` th entry is the distance between ``pdds[i]`` and
        ``pdds_[j]`` given by Earth mover's distance. Empty inputs give
        an empty matrix of the corresponding shape.
    """

    # EMD must return plain floats here, never (emd, transport_plan).
    kwargs.pop('return_transport', None)
    n, m = len(pdds), len(pdds_)

    if n_jobs is not None and n_jobs not in (0, 1):
        # Build the bound EMD callable once rather than once per pair.
        emd_func = partial(EMD, metric=metric, **kwargs)
        # TODO: put results into preallocated empty array in place
        flat = Parallel(
            backend=backend, n_jobs=n_jobs, verbose=3 if verbose else 0
        )(
            delayed(emd_func)(pdd, pdd_) for pdd in pdds for pdd_ in pdds_
        )
        return np.array(flat).reshape((n, m))

    dm = np.empty((n, m))
    progress_bar = None
    # k only labels the progress bar; reading pdds[0] is skipped for an
    # empty list so empty input yields an empty matrix, not an IndexError.
    if verbose and n > 0:
        k = pdds[0].shape[-1] - 1
        desc = f'Comparing {n}x{m} items by PDD (k={k})'
        progress_bar = tqdm.tqdm(desc=desc, total=n*m)
    try:
        for i, pdd in enumerate(pdds):
            for j, pdd_ in enumerate(pdds_):
                dm[i, j] = EMD(pdd, pdd_, metric=metric, **kwargs)
                if progress_bar is not None:
                    progress_bar.update(1)
    finally:
        # Close the bar even if a comparison raises.
        if progress_bar is not None:
            progress_bar.close()

    return dm
462
463
464
def PDD_pdist(
        pdds: List[npt.NDArray],
        metric: str = 'chebyshev',
        backend: str = 'multiprocessing',
        n_jobs: Optional[int] = None,
        verbose: bool = False,
        **kwargs
) -> npt.NDArray:
    """Compare a set of PDDs pairwise, returning a condensed distance
    matrix. Supports parallelisation via joblib. If using
    parallelisation, make sure to include a if __name__ == '__main__'
    guard around this function.

    Parameters
    ----------
    pdds : List[:class:`numpy.ndarray`]
        A list of PDDs.
    metric : str or callable, default 'chebyshev'
        Usually PDD rows are compared with the Chebyshev/l-infinity
        distance. Accepts any metric accepted by
        :func:`scipy.spatial.distance.cdist`.
    backend : str, default 'multiprocessing'
        The parallelization backend implementation. For a list of
        supported backends, see the backend argument of
        :class:`joblib.Parallel`.
    n_jobs : int, default None
        Maximum number of concurrent jobs for parallel processing with
        ``joblib``. Set to -1 to use the maximum. ``None``, 0 and 1 all
        run serially in this process. Using parallel processing may be
        slower for small inputs.
    verbose : bool, default False
        Prints a progress bar. If using parallel processing, the
        verbose argument of :class:`joblib.Parallel` is used, otherwise
        uses tqdm.
    **kwargs
        Extra keyword arguments forwarded to :func:`EMD`.
        ``return_transport`` is ignored.

    Returns
    -------
    cdm : :class:`numpy.ndarray`
        Returns a condensed distance matrix. Collapses a square distance
        matrix into a vector, just keeping the upper half. See the
        function :func:`squareform <scipy.spatial.distance.squareform>`
        from SciPy to convert to a symmetric square distance matrix.
        Fewer than two PDDs give an empty vector.
    """

    # EMD must return plain floats here, never (emd, transport_plan).
    kwargs.pop('return_transport', None)
    m = len(pdds)

    # Same gate as PDD_cdist: negative n_jobs (e.g. -1 for "all CPUs")
    # must take the parallel path, which a plain `n_jobs > 1` test missed.
    if n_jobs is not None and n_jobs not in (0, 1):
        # Build the bound EMD callable once rather than once per pair.
        emd_func = partial(EMD, metric=metric, **kwargs)
        # TODO: put results into preallocated empty array in place
        flat = Parallel(
            backend=backend, n_jobs=n_jobs, verbose=3 if verbose else 0
        )(
            delayed(emd_func)(pdds[i], pdds[j])
            for i, j in combinations(range(m), 2)
        )
        return np.array(flat, dtype=np.float64)

    cdm_len = (m * (m - 1)) // 2
    cdm = np.empty(cdm_len, dtype=np.float64)
    progress_bar = None
    # k only labels the progress bar; reading pdds[0] is skipped for an
    # empty list so empty input yields an empty vector, not an IndexError.
    if verbose and m > 0:
        k = pdds[0].shape[-1] - 1
        desc = f'Comparing {m} items pairwise by PDD (k={k})'
        progress_bar = tqdm.tqdm(desc=desc, total=cdm_len)
    try:
        # combinations() yields (i, j) with i < j in condensed-matrix order.
        for r, (i, j) in enumerate(combinations(range(m), 2)):
            cdm[r] = EMD(pdds[i], pdds[j], metric=metric, **kwargs)
            if progress_bar is not None:
                progress_bar.update(1)
    finally:
        # Close the bar even if a comparison raises.
        if progress_bar is not None:
            progress_bar.close()
    return cdm
535
536
537
def emd(pdd: npt.NDArray, pdd_: npt.NDArray, **kwargs) -> float:
    """Lower-case alias for :func:`EMD() <.compare.EMD>`; all arguments
    are passed through unchanged.
    """
    distance = EMD(pdd, pdd_, **kwargs)
    return distance
540
541
542
def _unwrap_periodicset_list(psets_or_str, **reader_kwargs):
    """Normalise valid input for amd.compare() into a flat list of
    ``PeriodicSet`` objects.

    Parameters
    ----------
    psets_or_str : PeriodicSet, tuple, str, file object or list of such
        A single item or a list of items. Each item is converted as
        described in ``_extract_periodicsets`` below and the results
        are concatenated in order.
    **reader_kwargs
        Keyword arguments forwarded to ``CifReader`` or ``CSDReader``.

    Returns
    -------
    list
        List of ``PeriodicSet`` objects (possibly empty).
    """

    def _extract_periodicsets(item, **reader_kwargs):
        """str (path/refcode), file, tuple or ``PeriodicSet`` --> list
        of ``PeriodicSets``.
        """

        if isinstance(item, PeriodicSet):
            return [item]
        # Check and index the item itself, not the enclosing argument, so
        # tuples inside a list are converted too rather than being handed
        # to CifReader.
        # NOTE(review): the two elements are passed positionally to
        # PeriodicSet; presumably (motif, cell) - confirm against the class.
        if isinstance(item, tuple):
            return [PeriodicSet(item[0], item[1])]
        # A string that is neither an existing file nor a directory is
        # handed to CSDReader (i.e. treated as a refcode).
        if isinstance(item, str) and not os.path.isfile(item) \
                                 and not os.path.isdir(item):
            reader_kwargs.pop('reader', None)
            return list(CSDReader(item, **reader_kwargs))
        # Everything else goes to CifReader, without the options that are
        # dropped here as CSD-specific.
        reader_kwargs.pop('families', None)
        reader_kwargs.pop('refcodes', None)
        return list(CifReader(item, **reader_kwargs))

    if isinstance(psets_or_str, list):
        return [s for item in psets_or_str
                for s in _extract_periodicsets(item, **reader_kwargs)]
    return _extract_periodicsets(psets_or_str, **reader_kwargs)
568