Passed
Push — master ( eb28f3...7c7660 )
by Daniel
04:02
created

amd.compare.compare()   D

Complexity: Conditions 12
Size: Total Lines 205, Code Lines 75
Duplication: Lines 0, Ratio 0 %
Importance: Changes 0

Metric   Value
eloc     75
dl       0
loc      205
rs       4.1781
c        0
b        0
f        0
cc       12
nop      17

How to fix

Long Method

Small methods make your code easier to understand, in particular if combined with a good name. Besides, if your method is small, finding a good name is usually much easier.

For example, if you find yourself adding comments to a method's body, this is usually a good sign to extract the commented part to a new method, and use the comment as a starting point when coming up with a good name for this new method.
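Applied to compare(), one candidate is the block at the end that turns dm into a DataFrame (the "if nearest: ... else: ..." tail). The sketch below extracts it into its own function. The names nearest_neighbours and distance_matrix_to_dataframe are hypothetical. nearest_neighbours is a simplified stand-in for amd's neighbours_from_distance_matrix. Excluding self-matches when dm is condensed is an assumption about the intended behaviour, not something taken from amd.

```python
import numpy as np
import pandas as pd
from scipy.spatial.distance import squareform


def nearest_neighbours(square_dm: np.ndarray, n: int):
    """Hypothetical stand-in for amd.utils.neighbours_from_distance_matrix:
    distances and column indices of the n smallest entries in each row."""
    inds = np.argsort(square_dm, axis=1)[:, :n]
    dists = np.take_along_axis(square_dm, inds, axis=1)
    return dists, inds


def distance_matrix_to_dataframe(dm, names, names_, nearest=None):
    """Hypothetical helper extracted from the end of compare()."""
    # A 1D dm is a condensed (pdist-style) matrix; expand it to square form.
    square = squareform(dm) if dm.ndim == 1 else dm
    if not nearest:
        return pd.DataFrame(square, index=names, columns=names_)
    if dm.ndim == 1:
        # Condensed input means the set was compared with itself; assume a
        # crystal should not be listed as its own nearest neighbour.
        # squareform returned a new array, so the caller's dm is untouched.
        np.fill_diagonal(square, np.inf)
    nn_dm, inds = nearest_neighbours(square, nearest)
    data = {}
    for i in range(nearest):
        data[f'ID {i + 1}'] = [names_[j] for j in inds[:, i]]
        data[f'DIST {i + 1}'] = nn_dm[:, i]
    return pd.DataFrame(data, index=names)


# Condensed distances for the pairs (a, b), (a, c), (b, c).
cdm = np.array([1.0, 4.0, 2.0])
names = ['a', 'b', 'c']
full = distance_matrix_to_dataframe(cdm, names, names)
nn = distance_matrix_to_dataframe(cdm, names, names, nearest=1)
```

With a helper like this, compare() could end with a single return distance_matrix_to_dataframe(dm, names, names_, nearest), and the DataFrame logic could be tested without reading any CIFs.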

Commonly applied refactorings include:

Complexity

Complex code like amd.compare.compare() often does many different things. To break it down, we need to identify a cohesive component within it. A common approach to finding such a component is to look for fields/methods (in a module like this one, functions and parameters) that share the same prefixes or suffixes.

Once you have determined which fields belong together, you can apply the Extract Class refactoring. If the component makes sense as a subclass, Extract Subclass is also a candidate, and is often faster.
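In this module, the shared prefixes point at one such component: AMD_pdist/AMD_cdist and PDD_pdist/PDD_cdist, which compare() selects between with nested if/else branches. The following is a minimal sketch of Extract Class applied to them. InvariantComparison, COMPARISONS and compare_invariants are hypothetical names. To keep the example self-contained, the toy 'AMD' entry uses SciPy directly instead of amd's functions.

```python
from dataclasses import dataclass
from typing import Callable, Dict, Optional

import numpy as np
from scipy.spatial.distance import cdist, pdist


@dataclass(frozen=True)
class InvariantComparison:
    """Hypothetical class grouping the pairwise and cross comparison
    functions of one invariant (e.g. AMD_pdist and AMD_cdist)."""
    pairwise: Callable
    cross: Callable

    def run(self, invs, invs_: Optional[list] = None, **kwargs):
        # One set: condensed pairwise matrix; two sets: full cross matrix.
        if invs_ is None:
            return self.pairwise(invs, **kwargs)
        return self.cross(invs, invs_, **kwargs)


def vector_pdist(vecs, metric='chebyshev'):
    return pdist(np.asarray(vecs, dtype=float), metric=metric)


def vector_cdist(vecs, vecs_, metric='chebyshev'):
    return cdist(np.asarray(vecs, dtype=float),
                 np.asarray(vecs_, dtype=float), metric=metric)


# A registry replaces the if/elif chain on ``by``. Only a toy 'AMD' entry
# is registered here; in amd it could map to AMD_pdist / AMD_cdist, with a
# 'PDD' entry for PDD_pdist / PDD_cdist.
COMPARISONS: Dict[str, InvariantComparison] = {
    'AMD': InvariantComparison(pairwise=vector_pdist, cross=vector_cdist),
}


def compare_invariants(by: str, invs, invs_=None, **kwargs):
    try:
        comparison = COMPARISONS[by.upper()]
    except KeyError:
        raise ValueError(f"unsupported invariant {by!r}") from None
    return comparison.run(invs, invs_, **kwargs)


amds = [[1.0, 2.0], [1.5, 2.0], [0.0, 0.0]]
cdm = compare_invariants('amd', amds)
dm = compare_invariants('AMD', amds, [[0.0, 0.0]])
```

With this design, supporting another invariant means registering one entry rather than adding another elif branch, which moves conditions out of compare() itself.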

Many Parameters

Methods with many parameters are not only hard to understand, but their parameters also often become inconsistent when you need more or different data.

There are several approaches to avoid long parameter lists:
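One such approach is to introduce a parameter object. In compare(), seven of the 17 arguments exist only to be packed into reader_kwargs, so they could travel together as one value. The sketch below is hypothetical: ReaderOptions and compare_sketch are not part of amd. The defaults are copied from compare()'s signature.

```python
from dataclasses import asdict, dataclass
from typing import Optional


@dataclass(frozen=True)
class ReaderOptions:
    """Hypothetical parameter object for compare()'s reader arguments."""
    reader: str = 'ase'
    families: bool = False
    remove_hydrogens: bool = False
    disorder: str = 'skip'
    heaviest_component: bool = False
    molecular_centres: bool = False
    show_warnings: bool = True

    def to_kwargs(self) -> dict:
        # Produces the same keys as the reader_kwargs dict in compare().
        return asdict(self)


def compare_sketch(crystals, crystals_=None, by='AMD', k=100,
                   reader_options: Optional[ReaderOptions] = None):
    """Hypothetical slimmer signature; only shows where the options go."""
    options = reader_options or ReaderOptions()
    reader_kwargs = options.to_kwargs()
    # The rest of compare() would pass **reader_kwargs to
    # _unwrap_periodicset_list exactly as before.
    return reader_kwargs


kwargs = compare_sketch('data.cif',
                        reader_options=ReaderOptions(disorder='all_sites'))
```

Because the dataclass is frozen, one ReaderOptions instance can safely be reused across calls. Adding a new reader option would then change the class rather than every signature that forwards it.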

"""Functions for comparing AMDs and PDDs of crystals.
"""

import warnings
from typing import List, Optional, Union, Tuple
from functools import partial
from itertools import combinations
import os

import numpy as np
import numpy.typing as npt
import pandas as pd
from scipy.spatial.distance import cdist, pdist, squareform
from joblib import Parallel, delayed
import tqdm

from .io import CifReader, CSDReader
from .calculate import AMD, PDD
from ._emd import network_simplex
from .periodicset import PeriodicSet, PeriodicSetType
from .utils import neighbours_from_distance_matrix


def compare(
        crystals: Union[str, PeriodicSetType, List],
        crystals_: Optional[Union[str, PeriodicSetType, List]] = None,
        by: str = 'AMD',
        k: int = 100,
        nearest: Optional[int] = None,
        reader: str = 'ase',
        remove_hydrogens: bool = False,
        disorder: str = 'skip',
        heaviest_component: bool = False,
        molecular_centres: bool = False,
        families: bool = False,
        show_warnings: bool = True,
        collapse_tol: float = 1e-4,
        metric: str = 'chebyshev',
        n_jobs: Optional[int] = None,
        verbose: int = 0,
        low_memory: bool = False,
) -> pd.DataFrame:
    r"""Given one or two paths to CIFs/folders, lists of CSD refcodes or
    periodic sets, compare them and return a DataFrame of the distance
    matrix. The default is to compare by AMD with k = 100. Accepts most
    keyword arguments accepted by :class:`CifReader <.io.CifReader>`,
    :class:`CSDReader <.io.CSDReader>` and functions from
    :mod:`.compare`.

    Parameters
    ----------
    crystals : list of :class:`PeriodicSet <.periodicset.PeriodicSet>` or str
        One or a collection of paths, refcodes, file objects or
        :class:`PeriodicSets <.periodicset.PeriodicSet>`.
    crystals\_ : list of :class:`PeriodicSet <.periodicset.PeriodicSet>` or str, optional
        One or a collection of paths, refcodes, file objects or
        :class:`PeriodicSets <.periodicset.PeriodicSet>`.
    by : str, default 'AMD'
        Use AMD or PDD to compare crystals.
    k : int, default 100
        Number of neighbour atoms to use for AMD/PDD.
    nearest : int, default None
        If given, return only this many nearest neighbours of each
        crystal (IDs and distances) instead of the full distance matrix.
    reader : str, default 'ase'
        The backend package used to parse the CIF. The default is
        :code:`ase`; :code:`pymatgen` and :code:`gemmi` are also
        accepted, as well as :code:`ccdc` if csd-python-api is
        installed. The ccdc reader should be able to read any format
        accepted by :class:`ccdc.io.EntryReader`, though only CIFs have
        been tested.
    remove_hydrogens : bool, optional
        Remove hydrogens from the crystals.
    disorder : str, optional
        Controls how disordered structures are handled. Default is
        ``skip`` which skips any crystal with disorder, since disorder
        conflicts with the periodic set model. To read disordered
        structures anyway, choose either :code:`ordered_sites` to remove
        atoms with disorder or :code:`all_sites` to include all atoms
        regardless of disorder.
    heaviest_component : bool, optional, csd-python-api only
        Removes all but the heaviest molecule in
        the asymmetric unit, intended for removing solvents.
    molecular_centres : bool, default False, csd-python-api only
        Use the centres of molecules for comparison
        instead of centres of atoms.
    families : bool, optional, csd-python-api only
        Read all entries whose refcode starts with
        the given strings, or 'families' (e.g. giving 'DEBXIT' reads all
        entries with refcodes starting with DEBXIT).
    show_warnings : bool, optional
        Controls whether warnings that arise during reading are printed.
    collapse_tol : float, default 1e-4, ``by='PDD'`` only
        If two PDD rows have all elements closer
        than ``collapse_tol``, they are merged and weights are given to
        rows in proportion to the number of times they appeared.
    metric : str or callable, default 'chebyshev'
        The metric to compare AMDs/PDDs with. AMDs are compared directly
        with this metric. EMD is the metric used between PDDs, which
        requires giving a metric to use between PDD rows. Chebyshev
        (L-infinity) distance is the default. Accepts any metric
        accepted by :func:`scipy.spatial.distance.cdist`.
    n_jobs : int, default None, ``by='PDD'`` only
        Maximum number of concurrent jobs for
        parallel processing with :code:`joblib`. Set to -1 to use the
        maximum. Using parallel processing may be slower for small
        inputs.
    verbose : int, default 0, ``by='PDD'`` only
        Verbosity level. If using parallel processing (n_jobs > 1),
        passed to :class:`joblib.Parallel` where larger values = more
        verbose. Otherwise uses tqdm if verbose is True.
    low_memory : bool, default False, ``by='AMD'`` only
        Use a slower but more memory efficient
        method for large collections of AMDs (metric 'chebyshev' only).

    Returns
    -------
    df : :class:`pandas.DataFrame`
        DataFrame of the distance matrix for the given crystals compared
        by the chosen invariant.

    Raises
    ------
    ValueError
        If by is not 'AMD' or 'PDD', if either set given has no valid
        crystals to compare, or if crystals or crystals\_ are an invalid
        type.

    Examples
    --------
    Compare everything in a .cif (default, AMD with k=100)::

        df = amd.compare('data.cif')

    Compare everything in one cif with all crystals in all cifs in a
    directory (PDD, k=50)::

        df = amd.compare('data.cif', 'dir/to/cifs', by='PDD', k=50)

    **Examples (csd-python-api only)**

    Compare two crystals by CSD refcode (PDD, k=50)::

        df = amd.compare('DEBXIT01', 'DEBXIT02', by='PDD', k=50)

    Compare everything in a refcode family (AMD, k=100)::

        df = amd.compare('DEBXIT', families=True)
    """

    by = by.upper()
    if by not in ('AMD', 'PDD'):
        msg = f"parameter 'by' accepts 'AMD' or 'PDD', passed {by}"
        raise ValueError(msg)

    reader_kwargs = {
        'reader': reader,
        'families': families,
        'remove_hydrogens': remove_hydrogens,
        'disorder': disorder,
        'heaviest_component': heaviest_component,
        'molecular_centres': molecular_centres,
        'show_warnings': show_warnings,
    }

    pdd_kwargs = {
        'collapse': True,
        'collapse_tol': collapse_tol,
        'lexsort': False,
    }

    compare_kwargs = {
        'metric': metric,
        'n_jobs': n_jobs,
        'verbose': verbose,
        'low_memory': low_memory,
    }

    crystals = _unwrap_periodicset_list(crystals, **reader_kwargs)
    if not crystals:
        raise ValueError('No valid crystals to compare in first set.')
    names = [s.name for s in crystals]

    if crystals_ is None:
        names_ = names
    else:
        crystals_ = _unwrap_periodicset_list(crystals_, **reader_kwargs)
        if not crystals_:
            raise ValueError('No valid crystals to compare in second set.')
        names_ = [s.name for s in crystals_]

    if by == 'AMD':

        invs = [AMD(s, k) for s in crystals]
        compare_kwargs.pop('n_jobs', None)
        compare_kwargs.pop('verbose', None)

        if crystals_ is None:
            dm = AMD_pdist(invs, **compare_kwargs)
        else:
            invs_ = [AMD(s, k) for s in crystals_]
            dm = AMD_cdist(invs, invs_, **compare_kwargs)

    elif by == 'PDD':

        invs = [PDD(s, k, **pdd_kwargs) for s in crystals]
        compare_kwargs.pop('low_memory', None)

        if crystals_ is None:
            dm = PDD_pdist(invs, **compare_kwargs)
        else:
            invs_ = [PDD(s, k, **pdd_kwargs) for s in crystals_]
            dm = PDD_cdist(invs, invs_, **compare_kwargs)

    if nearest:
        nn_dm, inds = neighbours_from_distance_matrix(nearest, dm)
        data = {}
        for i in range(nearest):
            data['ID ' + str(i+1)] = [names_[j] for j in inds[:, i]]
            data['DIST ' + str(i+1)] = nn_dm[:, i]
        df = pd.DataFrame(data, index=names)

    else:
        if dm.ndim == 1:
            dm = squareform(dm)
        df = pd.DataFrame(dm, index=names, columns=names_)

    return df


def EMD(
        pdd: np.ndarray,
        pdd_: np.ndarray,
        metric: str = 'chebyshev',
        return_transport: bool = False,
        **kwargs
) -> Union[float, Tuple[float, np.ndarray]]:
    r"""Earth mover's distance (EMD) between two PDDs, also known as
    the Wasserstein metric.

    Parameters
    ----------
    pdd : :class:`numpy.ndarray`
        PDD of a crystal.
    pdd\_ : :class:`numpy.ndarray`
        PDD of a crystal.
    metric : str or callable, default 'chebyshev'
        EMD between PDDs requires defining a distance between PDD rows.
        By default, Chebyshev (L-infinity) distance is chosen, as with
        AMDs. Accepts any metric accepted by
        :func:`scipy.spatial.distance.cdist`.
    return_transport : bool, default False
        Instead return a tuple ``(emd, transport_plan)`` where
        transport_plan describes the optimal flow.

    Returns
    -------
    emd : float
        Earth mover's distance between two PDDs. If ``return_transport``
        is True, return a tuple (emd, transport_plan).

    Raises
    ------
    ValueError
        Thrown if ``pdd`` and ``pdd_`` do not have the same number of
        columns.
    """

    # Column 0 holds the row weights; the remaining columns are compared
    # with ``metric`` to build the cost matrix.
    dm = cdist(pdd[:, 1:], pdd_[:, 1:], metric=metric, **kwargs)
    emd_dist, transport_plan = network_simplex(pdd[:, 0], pdd_[:, 0], dm)

    if return_transport:
        return emd_dist, transport_plan

    return emd_dist


def AMD_cdist(
        amds: npt.ArrayLike,
        amds_: npt.ArrayLike,
        metric: str = 'chebyshev',
        low_memory: bool = False,
        **kwargs
) -> np.ndarray:
    r"""Compare two sets of AMDs with each other, returning a distance
    matrix. This function is essentially
    :func:`scipy.spatial.distance.cdist` with the default metric
    ``chebyshev`` and a low memory option.

    Parameters
    ----------
    amds : array_like
        A list/array of AMDs.
    amds\_ : array_like
        A list/array of AMDs.
    metric : str or callable, default 'chebyshev'
        Usually AMDs are compared with the Chebyshev (L-infinity) distance.
        Accepts any metric accepted by :func:`scipy.spatial.distance.cdist`.
    low_memory : bool, default False
        Use a slower but more memory efficient method for large collections of
        AMDs (metric 'chebyshev' only).

    Returns
    -------
    dm : :class:`numpy.ndarray`
        A distance matrix of shape ``(len(amds), len(amds_))``. ``dm[i, j]``
        is the distance (given by ``metric``) between ``amds[i]`` and
        ``amds_[j]``.
    """

    amds, amds_ = np.asarray(amds), np.asarray(amds_)

    if len(amds.shape) == 1:
        amds = np.array([amds])
    if len(amds_.shape) == 1:
        amds_ = np.array([amds_])

    if low_memory:
        if metric != 'chebyshev':
            msg = "Using only allowed metric 'chebyshev' for low_memory"
            warnings.warn(msg, UserWarning)

        dm = np.empty((len(amds), len(amds_)))
        for i, amd_vec in enumerate(amds):
            dm[i] = np.amax(np.abs(amds_ - amd_vec), axis=-1)
    else:
        dm = cdist(amds, amds_, metric=metric, **kwargs)

    return dm


def AMD_pdist(
        amds: npt.ArrayLike,
        metric: str = 'chebyshev',
        low_memory: bool = False,
        **kwargs
) -> np.ndarray:
    """Compare a set of AMDs pairwise, returning a condensed distance
    matrix. This function is essentially
    :func:`scipy.spatial.distance.pdist` with the default metric
    ``chebyshev`` and a low memory parameter.

    Parameters
    ----------
    amds : array_like
        A list/array of AMDs.
    metric : str or callable, default 'chebyshev'
        Usually AMDs are compared with the Chebyshev (L-infinity)
        distance. Accepts any metric accepted by
        :func:`scipy.spatial.distance.pdist`.
    low_memory : bool, default False
        Use a slower but more memory efficient method for large
        collections of AMDs (metric 'chebyshev' only).

    Returns
    -------
    cdm : :class:`numpy.ndarray`
        Returns a condensed distance matrix. Collapses a square distance
        matrix into a vector, just keeping the upper half. See the
        function :func:`squareform <scipy.spatial.distance.squareform>`
        from SciPy to convert to a symmetric square distance matrix.
    """

    amds = np.asarray(amds)

    if len(amds.shape) == 1:
        amds = np.array([amds])

    if low_memory:
        m = len(amds)
        if metric != 'chebyshev':
            msg = "Using only allowed metric 'chebyshev' for low_memory"
            warnings.warn(msg, UserWarning)
        cdm = np.empty((m * (m - 1)) // 2, dtype=np.double)
        ind = 0
        for i in range(m):
            ind_ = ind + m - i - 1
            cdm[ind:ind_] = np.amax(np.abs(amds[i+1:] - amds[i]), axis=-1)
            ind = ind_
    else:
        cdm = pdist(amds, metric=metric, **kwargs)

    return cdm


def PDD_cdist(
        pdds: List[np.ndarray],
        pdds_: List[np.ndarray],
        metric: str = 'chebyshev',
        backend: str = 'multiprocessing',
        n_jobs: Optional[int] = None,
        verbose: int = 0,
        **kwargs
) -> np.ndarray:
    r"""Compare two sets of PDDs with each other, returning a distance
    matrix. Supports parallel processing via joblib. If using
    parallelisation, make sure code that calls this function is placed
    behind an ``if __name__ == '__main__'`` guard.

    Parameters
    ----------
    pdds : List[:class:`numpy.ndarray`]
        A list of PDDs.
    pdds\_ : List[:class:`numpy.ndarray`]
        A list of PDDs.
    metric : str or callable, default 'chebyshev'
        Usually PDD rows are compared with the Chebyshev/l-infinity
        distance. Accepts any metric accepted by
        :func:`scipy.spatial.distance.cdist`.
    backend : str, default 'multiprocessing'
        The parallelization backend implementation. For a list of
        supported backends, see the backend argument of
        :class:`joblib.Parallel`.
    n_jobs : int, default None
        Maximum number of concurrent jobs for parallel processing with
        ``joblib``. Set to -1 to use the maximum. Using parallel
        processing may be slower for small inputs.
    verbose : int, default 0
        Verbosity level. If using parallel processing (n_jobs other than
        None, 0 or 1), passed to :class:`joblib.Parallel` where larger
        values = more verbose. Otherwise uses tqdm if verbose is True.

    Returns
    -------
    dm : :class:`numpy.ndarray`
        Returns a distance matrix of shape ``(len(pdds), len(pdds_))``.
        The :math:`ij` th entry is the distance between ``pdds[i]`` and
        ``pdds_[j]`` given by Earth mover's distance.
    """

    if isinstance(pdds, np.ndarray):
        if len(pdds.shape) == 2:
            pdds = [pdds]

    if isinstance(pdds_, np.ndarray):
        if len(pdds_.shape) == 2:
            pdds_ = [pdds_]

    kwargs.pop('return_transport', None)

    if n_jobs is not None and n_jobs not in (0, 1):
        # TODO: put results into preallocated empty array in place
        dm = Parallel(backend=backend, n_jobs=n_jobs, verbose=verbose)(
            delayed(partial(EMD, metric=metric, **kwargs))(pdds[i], pdds_[j])
            for i in range(len(pdds)) for j in range(len(pdds_))
        )
        dm = np.array(dm).reshape((len(pdds), len(pdds_)))

    else:
        n, m = len(pdds), len(pdds_)
        dm = np.empty((n, m))
        if verbose:
            pbar = tqdm.tqdm(total=n*m)
        for i in range(n):
            for j in range(m):
                dm[i, j] = EMD(pdds[i], pdds_[j], metric=metric, **kwargs)
                if verbose:
                    pbar.update(1)
        if verbose:
            pbar.close()

    return dm


def PDD_pdist(
        pdds: List[np.ndarray],
        metric: str = 'chebyshev',
        backend: str = 'multiprocessing',
        n_jobs: Optional[int] = None,
        verbose: int = 0,
        **kwargs
) -> np.ndarray:
    """Compare a set of PDDs pairwise, returning a condensed distance
    matrix. Supports parallelisation via joblib. If using
    parallelisation, make sure code that calls this function is placed
    behind an ``if __name__ == '__main__'`` guard.

    Parameters
    ----------
    pdds : List[:class:`numpy.ndarray`]
        A list of PDDs.
    metric : str or callable, default 'chebyshev'
        Usually PDD rows are compared with the Chebyshev/l-infinity
        distance. Accepts any metric accepted by
        :func:`scipy.spatial.distance.cdist`.
    backend : str, default 'multiprocessing'
        The parallelization backend implementation. For a list of
        supported backends, see the backend argument of
        :class:`joblib.Parallel`.
    n_jobs : int, default None
        Maximum number of concurrent jobs for parallel processing with
        ``joblib``. Set to -1 to use the maximum. Using parallel
        processing may be slower for small inputs.
    verbose : int, default 0
        Verbosity level. If using parallel processing (n_jobs other than
        None, 0 or 1), passed to :class:`joblib.Parallel` where larger
        values = more verbose. Otherwise uses tqdm if verbose is True.

    Returns
    -------
    cdm : :class:`numpy.ndarray`
        Returns a condensed distance matrix. Collapses a square distance
        matrix into a vector, just keeping the upper half. See the
        function :func:`squareform <scipy.spatial.distance.squareform>`
        from SciPy to convert to a symmetric square distance matrix.
    """

    kwargs.pop('return_transport', None)

    # Same condition as PDD_cdist, so that n_jobs=-1 ("use all cores" in
    # joblib) also runs in parallel, as the docstring promises.
    if n_jobs is not None and n_jobs not in (0, 1):
        # TODO: put results into preallocated empty array in place
        cdm = Parallel(backend=backend, n_jobs=n_jobs, verbose=verbose)(
            delayed(partial(EMD, metric=metric, **kwargs))(pdds[i], pdds[j])
            for i, j in combinations(range(len(pdds)), 2)
        )
        cdm = np.array(cdm)

    else:
        m = len(pdds)
        cdm_len = (m * (m - 1)) // 2
        cdm = np.empty(cdm_len, dtype=np.double)
        inds = ((i, j) for i in range(0, m - 1) for j in range(i + 1, m))
        if verbose:
            # total must be passed by keyword: tqdm's first positional
            # argument is the iterable, not the length.
            eta = tqdm.tqdm(total=cdm_len)
        for r, (i, j) in enumerate(inds):
            cdm[r] = EMD(pdds[i], pdds[j], metric=metric, **kwargs)
            if verbose:
                eta.update(1)
        if verbose:
            eta.close()
    return cdm


def emd(pdd: np.ndarray, pdd_: np.ndarray, **kwargs) -> float:
    """Alias for :func:`EMD() <.compare.EMD>`."""
    return EMD(pdd, pdd_, **kwargs)


def _unwrap_periodicset_list(psets_or_str, **reader_kwargs):
    """Valid input for amd.compare() (``PeriodicSet``, path, refcode,
    lists of such) --> list of PeriodicSets.
    """

    def _extract_periodicsets(item, **reader_kwargs):
        """str (path/refcode), file, tuple or ``PeriodicSet`` --> list of
        ``PeriodicSets``.
        """

        if isinstance(item, PeriodicSet):
            return [item]
        # Check the current item rather than the outer argument, so that
        # tuples inside a list are handled as well.
        if isinstance(item, tuple):
            return [PeriodicSet(item[0], item[1])]
        if isinstance(item, str) and not os.path.isfile(item) \
                                 and not os.path.isdir(item):
            reader_kwargs.pop('reader', None)
            return list(CSDReader(item, **reader_kwargs))
        reader_kwargs.pop('families', None)
        reader_kwargs.pop('refcodes', None)
        return list(CifReader(item, **reader_kwargs))

    if isinstance(psets_or_str, list):
        return [s for item in psets_or_str
                for s in _extract_periodicsets(item, **reader_kwargs)]
    return _extract_periodicsets(psets_or_str, **reader_kwargs)
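The low_memory branch of AMD_pdist fills the condensed matrix one row at a time. Row i contributes the m - i - 1 pairs (i, j) with j > i, which is the same ordering scipy.spatial.distance.pdist uses. The standalone sketch below restates that loop and checks it against pdist. chebyshev_pdist_low_memory and condensed_index are hypothetical names, not amd's API. The toy AMDs are random vectors rather than real crystal data.

```python
import numpy as np
from scipy.spatial.distance import pdist, squareform


def chebyshev_pdist_low_memory(vectors) -> np.ndarray:
    """Condensed Chebyshev distance matrix built row by row, following
    the low_memory branch of AMD_pdist."""
    vectors = np.asarray(vectors, dtype=np.double)
    m = len(vectors)
    cdm = np.empty(m * (m - 1) // 2, dtype=np.double)
    start = 0
    for i in range(m - 1):
        stop = start + m - i - 1  # row i holds the pairs (i, i+1) ... (i, m-1)
        cdm[start:stop] = np.abs(vectors[i + 1:] - vectors[i]).max(axis=-1)
        start = stop
    return cdm


def condensed_index(i: int, j: int, m: int) -> int:
    """Position of the pair (i, j), with i < j, in a condensed matrix."""
    return m * i - i * (i + 1) // 2 + (j - i - 1)


rng = np.random.default_rng(0)
amds = rng.random((6, 4))  # six toy "AMDs" with k = 4
cdm = chebyshev_pdist_low_memory(amds)
square = squareform(cdm)
```

The loop stops at m - 1 because the last row has no pairs left. The original loops over range(m), and that final iteration writes an empty slice, so both give the same result.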
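EMD() delegates the optimisation to amd's own network_simplex. The transportation problem it solves can also be written as a linear program, which helps show what the returned distance and transport plan mean. The sketch below swaps in scipy.optimize.linprog (HiGHS solver) in place of network simplex. It is far slower and is intended only for checking small cases. emd_linprog is a hypothetical name. Following EMD(), it treats column 0 as the row weights. It also assumes both weight columns sum to the same total (1 for normalised weights); otherwise the program is infeasible.

```python
import numpy as np
from scipy.optimize import linprog
from scipy.spatial.distance import cdist


def emd_linprog(pdd, pdd_, metric='chebyshev'):
    """EMD between two PDD-like arrays (column 0 = weights), solved as a
    transportation linear program instead of by network simplex."""
    w, w_ = pdd[:, 0], pdd_[:, 0]
    cost = cdist(pdd[:, 1:], pdd_[:, 1:], metric=metric)
    n, m = cost.shape
    # Flow f[i, j] is flattened row-major into variable i * m + j.
    a_eq = np.zeros((n + m, n * m))
    for i in range(n):
        a_eq[i, i * m:(i + 1) * m] = 1.0  # flow out of row i equals w[i]
    for j in range(m):
        a_eq[n + j, j::m] = 1.0           # flow into row j equals w_[j]
    b_eq = np.concatenate([w, w_])
    res = linprog(cost.ravel(), A_eq=a_eq, b_eq=b_eq,
                  bounds=(0, None), method='highs')
    if not res.success:
        raise RuntimeError(res.message)
    return res.fun, res.x.reshape(n, m)


pdd = np.array([[0.5, 1.0, 2.0],
                [0.5, 3.0, 4.0]])
pdd_ = np.array([[1.0, 1.0, 2.0]])
dist, plan = emd_linprog(pdd, pdd_)
```

Here all mass must flow to the single row of pdd_. The first row of pdd moves at Chebyshev cost 0 and the second at cost 2, so the distance is 0.5 * 0 + 0.5 * 2 = 1.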