Passed
Push — master ( e66c5e...1f4459 )
by Daniel
07:53
created

amd.compare   C

Complexity

Total Complexity 57

Size/Duplication

Total Lines 515
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
wmc 57
eloc 216
dl 0
loc 515
rs 5.04
c 0
b 0
f 0

9 Functions

Rating   Name   Duplication   Size   Complexity  
B AMD_pdist() 0 52 5
D PDD_cdist() 0 77 12
B PDD_pdist() 0 67 7
A _unwrap_periodicset_list() 0 11 3
A EMD() 0 44 2
A emd() 0 3 1
F compare() 0 158 16
A _extract_periodicsets() 0 13 5
B AMD_cdist() 0 51 6

How to fix   Complexity   

Complexity

Complex modules like amd.compare often do a lot of different things. To break such a module down, identify a cohesive component within it. A common approach to finding such a component is to look for functions or variables that share the same prefixes or suffixes.

Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a sub-class, Extract Subclass is also a candidate, and is often faster.

1
"""Functions for comparing AMDs and PDDs of crystals.
2
"""
3
4
import warnings
5
from typing import List, Optional, Union
6
from functools import partial
7
from itertools import combinations
8
import os
9
10
import numpy as np
11
import pandas as pd
12
from scipy.spatial.distance import cdist, pdist, squareform
13
from joblib import Parallel, delayed
14
import tqdm
15
16
from .amdio import CifReader, CSDReader
17
from .calculate import AMD, PDD
18
from ._emd import network_simplex
19
from .periodicset import PeriodicSet
20
from .utils import neighbours_from_distance_matrix
21
22
23
def compare(
        crystals,
        crystals_=None,
        by='AMD',
        k=100,
        nearest=None,
        **kwargs
) -> pd.DataFrame:
    r"""Given one or two sets of periodic set(s), paths to cif(s) or
    refcode(s), compare them and return a DataFrame of the distance
    matrix. Default is to compare by AMD with k = 100. Accepts most
    keyword arguments accepted by :class:`CifReader <.amdio.CifReader>`,
    :class:`CSDReader <.amdio.CSDReader>` and functions from
    :mod:`.compare`, for a full list see the documentation Quick Start
    page. Note that using refcodes requires ``csd-python-api``.

    Parameters
    ----------
    crystals : list of :class:`PeriodicSet <.periodicset.PeriodicSet>` or str
        One or a collection of paths, refcodes, file objects or
        :class:`PeriodicSets <.periodicset.PeriodicSet>`.
    crystals\_ : list of :class:`PeriodicSet <.periodicset.PeriodicSet>` or str, optional
        One or a collection of paths, refcodes, file objects or
        :class:`PeriodicSets <.periodicset.PeriodicSet>`.
    by : str, default 'AMD'
        Use AMD or PDD to compare crystals (case-insensitive).
    k : int, default 100
        Number of neighbour atoms to use for AMD/PDD.
    nearest : int, default None
        Find a number of nearest neighbours instead of a full distance
        matrix between crystals.
    **kwargs :
        Optional arguments to be passed to io, calculate or compare
        functions. Recognised keys are ``reader``, ``families``,
        ``remove_hydrogens``, ``disorder``, ``heaviest_component``,
        ``molecular_centres`` and ``show_warnings`` (reading);
        ``collapse``, ``collapse_tol`` and ``lexsort`` (PDD calculation);
        ``metric``, ``n_jobs``, ``verbose`` and ``low_memory``
        (comparison). Any other keys are ignored.

    Returns
    -------
    df : :class:`pandas.DataFrame`
        DataFrame of the distance matrix for the given crystals compared
        by the chosen invariant.

    Raises
    ------
    ValueError
        If by is not 'AMD' or 'PDD', if either set given have no valid
        crystals to compare, or if crystals or crystals\_ are an invalid
        type.

    Examples
    --------
    Compare everything in a .cif (default, AMD with k=100)::

        df = amd.compare('data.cif')

    Compare everything in one cif with all crystals in all cifs in a
    directory (PDD, k=50)::

        df = amd.compare('data.cif', 'dir/to/cifs', by='PDD', k=50)

    **Examples (csd-python-api only)**

    Compare two crystals by CSD refcode (PDD, k=50)::

        df = amd.compare('DEBXIT01', 'DEBXIT02', by='PDD', k=50)

    Compare everything in a refcode family (AMD, k=100)::

        df = amd.compare('DEBXIT', families=True)
    """

    # str() first so a non-string `by` (e.g. None) gives the documented
    # ValueError rather than an AttributeError from .upper().
    by = str(by).upper()
    if by not in ('AMD', 'PDD'):
        msg = f"parameter 'by' accepts 'AMD' or 'PDD', was passed {by}"
        raise ValueError(msg)

    # Defaults for each stage. User kwargs override only keys that appear
    # here; any other kwargs are ignored.
    reader_kwargs = {
        'reader': 'ase',
        'families': False,
        'remove_hydrogens': False,
        'disorder': 'skip',
        'heaviest_component': False,
        'molecular_centres': False,
        'show_warnings': True,
    }

    calc_kwargs = {
        'collapse': True,
        'collapse_tol': 1e-4,
        'lexsort': False,
    }

    compare_kwargs = {
        'metric': 'chebyshev',
        'n_jobs': None,
        'verbose': 0,
        'low_memory': False,
    }

    for default_kwargs in (reader_kwargs, calc_kwargs, compare_kwargs):
        for key in kwargs.keys() & default_kwargs.keys():
            default_kwargs[key] = kwargs[key]

    crystals = _unwrap_periodicset_list(crystals, **reader_kwargs)
    if not crystals:
        raise ValueError('No valid crystals to compare in first set.')
    names = [s.name for s in crystals]

    if crystals_ is None:
        names_ = names
    else:
        crystals_ = _unwrap_periodicset_list(crystals_, **reader_kwargs)
        if not crystals_:
            raise ValueError('No valid crystals to compare in second set.')
        names_ = [s.name for s in crystals_]

    if reader_kwargs['molecular_centres']:
        # Replace each PeriodicSet with a (molecular_centres, cell) tuple,
        # which is what AMD/PDD are then called with.
        crystals = [(c.molecular_centres, c.cell) for c in crystals]
        if crystals_ is not None:
            crystals_ = [(c.molecular_centres, c.cell) for c in crystals_]

    dm = _compare_distance_matrix(
        crystals, crystals_, by, k, calc_kwargs, compare_kwargs
    )
    return _compare_dataframe(dm, names, names_, nearest)


def _compare_distance_matrix(
        crystals, crystals_, by, k, calc_kwargs, compare_kwargs
):
    """Compute invariants and their distance matrix for :func:`compare`.

    Returns a condensed distance matrix when ``crystals_`` is None,
    otherwise a ``(len(crystals), len(crystals_))`` distance matrix.
    """

    if by == 'AMD':
        # AMD_pdist/AMD_cdist forward extra kwargs to scipy, which does not
        # accept the joblib options n_jobs/verbose, so drop them.
        amd_kwargs = {key: val for key, val in compare_kwargs.items()
                      if key not in ('n_jobs', 'verbose')}
        invs = [AMD(s, k) for s in crystals]
        if crystals_ is None:
            return AMD_pdist(invs, **amd_kwargs)
        invs_ = [AMD(s, k) for s in crystals_]
        return AMD_cdist(invs, invs_, **amd_kwargs)

    # PDD_pdist/PDD_cdist forward extra kwargs through EMD to scipy's
    # cdist, which has no low_memory option.
    pdd_kwargs = {key: val for key, val in compare_kwargs.items()
                  if key != 'low_memory'}
    invs = [PDD(s, k, **calc_kwargs) for s in crystals]
    if crystals_ is None:
        return PDD_pdist(invs, **pdd_kwargs)
    invs_ = [PDD(s, k, **calc_kwargs) for s in crystals_]
    return PDD_cdist(invs, invs_, **pdd_kwargs)


def _compare_dataframe(dm, names, names_, nearest):
    """Wrap a distance matrix from :func:`compare` in a DataFrame.

    If ``nearest`` is truthy, each row (indexed by ``names``) instead lists
    the names ('ID n') and distances ('DIST n') of its ``nearest`` nearest
    neighbours taken from ``names_``.
    """

    if nearest:
        nn_dm, inds = neighbours_from_distance_matrix(nearest, dm)
        data = {}
        for i in range(nearest):
            data[f'ID {i + 1}'] = [names_[j] for j in inds[:, i]]
            data[f'DIST {i + 1}'] = nn_dm[:, i]
        return pd.DataFrame(data, index=names)

    # pdist-style results are condensed; expand to a square matrix.
    if dm.ndim == 1:
        dm = squareform(dm)
    return pd.DataFrame(dm, index=names, columns=names_)
181
182
183
def EMD(
        pdd: np.ndarray,
        pdd_: np.ndarray,
        metric: Optional[str] = 'chebyshev',
        return_transport: Optional[bool] = False,
        **kwargs):
    r"""Earth mover's distance (EMD), also known as the Wasserstein
    metric, between two PDDs.

    Parameters
    ----------
    pdd : :class:`numpy.ndarray`
        PDD of a crystal. Column 0 is passed to the network simplex solver
        as the weight of each row (presumably the row weights — confirm
        against ``network_simplex``). The remaining columns are compared
        between rows.
    pdd\_ : :class:`numpy.ndarray`
        PDD of a crystal, same layout as ``pdd``.
    metric : str or callable, default 'chebyshev'
        Distance used between PDD rows when building the EMD cost matrix.
        Chebyshev (L-infinity) is the default, matching AMD comparison.
        Accepts any metric accepted by
        :func:`scipy.spatial.distance.cdist`.
    return_transport: bool, default False
        If True, return the tuple ``(emd, transport_plan)``, where
        transport_plan describes the optimal flow.
    **kwargs
        Forwarded to :func:`scipy.spatial.distance.cdist`.

    Returns
    -------
    emd : float
        Earth mover's distance between the two PDDs, or the tuple
        ``(emd, transport_plan)`` when ``return_transport`` is True.

    Raises
    ------
    ValueError
        Raised (by ``cdist``) if ``pdd`` and ``pdd_`` do not have the same
        number of columns.
    """

    # Pairwise distances between the rows of the two PDDs, weight column
    # excluded.
    row_distances = cdist(pdd[:, 1:], pdd_[:, 1:], metric=metric, **kwargs)
    distance, plan = network_simplex(pdd[:, 0], pdd_[:, 0], row_distances)
    return (distance, plan) if return_transport else distance
227
228
229
def AMD_cdist(
        amds: Union[np.ndarray, List[np.ndarray]],
        amds_: Union[np.ndarray, List[np.ndarray]],
        metric: str = 'chebyshev',
        low_memory: bool = False,
        **kwargs
) -> np.ndarray:
    r"""Compare two sets of AMDs with each other, returning a distance
    matrix. This function is essentially
    :func:`scipy.spatial.distance.cdist` with the default metric
    ``chebyshev`` and a low memory option.

    Parameters
    ----------
    amds : array_like
        A list/array of AMDs, or a single AMD vector.
    amds\_ : array_like
        A list/array of AMDs, or a single AMD vector.
    metric : str or callable, default 'chebyshev'
        Usually AMDs are compared with the Chebyshev (L-infinity) distance.
        Accepts any metric accepted by :func:`scipy.spatial.distance.cdist`.
    low_memory : bool, default False
        Use a slower but more memory efficient method for large collections
        of AMDs (metric 'chebyshev' only). If another metric is given with
        ``low_memory=True``, a warning is issued and the Chebyshev distance
        is used anyway.
    **kwargs
        Passed to :func:`scipy.spatial.distance.cdist`. Ignored when
        ``low_memory`` is True.

    Returns
    -------
    dm : :class:`numpy.ndarray`
        A distance matrix of shape ``(len(amds), len(amds_))``.
        ``dm[i, j]`` is the distance (given by ``metric``) between
        ``amds[i]`` and ``amds_[j]``.
    """

    amds, amds_ = np.asarray(amds), np.asarray(amds_)

    # A single 1-D AMD is treated as a collection containing one AMD.
    if amds.ndim == 1:
        amds = amds[np.newaxis, :]
    if amds_.ndim == 1:
        amds_ = amds_[np.newaxis, :]

    if low_memory:
        if metric != 'chebyshev':
            msg = "Using only allowed metric 'chebyshev' for low_memory"
            warnings.warn(msg, UserWarning)

        # Fill the matrix one row at a time, only ever broadcasting a
        # single AMD against amds_.
        dm = np.empty((len(amds), len(amds_)))
        for i, amd_vec in enumerate(amds):
            dm[i] = np.amax(np.abs(amds_ - amd_vec), axis=-1)
        return dm

    return cdist(amds, amds_, metric=metric, **kwargs)
280
281
282
def AMD_pdist(
        amds: Union[np.ndarray, List[np.ndarray]],
        metric: str = 'chebyshev',
        low_memory: bool = False,
        **kwargs
) -> np.ndarray:
    """Compare a set of AMDs pairwise, returning a condensed distance
    matrix. Behaves like :func:`scipy.spatial.distance.pdist`, except that
    the default metric is ``chebyshev`` and a low memory mode is available.

    Parameters
    ----------
    amds : array_like
        A list/array of AMDs. A single 1-D AMD is treated as a collection
        of one, which gives an empty result.
    metric : str or callable, default 'chebyshev'
        Usually AMDs are compared with the Chebyshev (L-infinity)
        distance. Accepts any metric accepted by
        :func:`scipy.spatial.distance.pdist`.
    low_memory : bool, default False
        Use a slower but more memory efficient method for large
        collections of AMDs (metric 'chebyshev' only). Any other metric
        triggers a warning and Chebyshev is used; ``kwargs`` are ignored.

    Returns
    -------
    cdm : :class:`numpy.ndarray`
        A condensed distance matrix: the upper triangle of the square
        distance matrix flattened into a vector. Use
        :func:`squareform <scipy.spatial.distance.squareform>` from SciPy
        to expand it into a symmetric square distance matrix.
    """

    vectors = np.asarray(amds)
    if vectors.ndim == 1:
        vectors = np.array([vectors])

    if not low_memory:
        return pdist(vectors, metric=metric, **kwargs)

    count = len(vectors)
    if metric != 'chebyshev':
        warnings.warn(
            "Using only allowed metric 'chebyshev' for low_memory",
            UserWarning
        )

    cdm = np.empty((count * (count - 1)) // 2, dtype=np.double)
    start = 0
    for i, vec in enumerate(vectors):
        # Distances from row i to every later row fill the next block of
        # the condensed matrix.
        later = vectors[i + 1:]
        stop = start + len(later)
        cdm[start:stop] = np.amax(np.abs(later - vec), axis=-1)
        start = stop
    return cdm
334
335
336
def PDD_cdist(
        pdds: List[np.ndarray],
        pdds_: List[np.ndarray],
        metric: str = 'chebyshev',
        backend='multiprocessing',
        n_jobs=None,
        verbose=0,
        **kwargs
) -> np.ndarray:
    r"""Compare every PDD in one set with every PDD in another by Earth
    mover's distance, returning a distance matrix. Work can be spread over
    several processes with joblib. When parallelising, call this function
    from inside an ``if __name__ == '__main__'`` guard.

    Parameters
    ----------
    pdds : List[:class:`numpy.ndarray`]
        A list of PDDs, or a single 2-D PDD.
    pdds\_ : List[:class:`numpy.ndarray`]
        A list of PDDs, or a single 2-D PDD.
    metric : str or callable, default 'chebyshev'
        Distance between PDD rows, usually Chebyshev/L-infinity. Accepts
        any metric accepted by :func:`scipy.spatial.distance.cdist`.
    backend : str, default 'multiprocessing'
        The joblib parallelisation backend; see the ``backend`` argument
        of :class:`joblib.Parallel` for the options.
    n_jobs : int, default None
        Maximum number of concurrent joblib jobs; -1 uses the maximum.
        None, 0 or 1 run serially. Parallel processing may be slower for
        small inputs.
    verbose : int, default 0
        Verbosity level. When parallel it is passed to
        :class:`joblib.Parallel` (larger is more verbose). When serial,
        any truthy value shows a progress bar.
    **kwargs
        Forwarded to :func:`EMD`. ``return_transport`` is discarded.

    Returns
    -------
    dm : :class:`numpy.ndarray`
        Distance matrix of shape ``(len(pdds), len(pdds_))``. Entry
        ``[i, j]`` is the Earth mover's distance between ``pdds[i]`` and
        ``pdds_[j]``.
    """

    # A lone PDD (a 2-D array) is treated as a list holding that one PDD.
    if isinstance(pdds, np.ndarray) and len(pdds.shape) == 2:
        pdds = [pdds]
    if isinstance(pdds_, np.ndarray) and len(pdds_.shape) == 2:
        pdds_ = [pdds_]

    # Only distances are stored, so a transport plan cannot be returned.
    kwargs.pop('return_transport', None)

    parallel = n_jobs is not None and n_jobs not in (0, 1)
    if parallel:
        n_rows, n_cols = len(pdds), len(pdds_)
        emd_func = partial(EMD, metric=metric, **kwargs)
        flat = Parallel(backend=backend, n_jobs=n_jobs, verbose=verbose)(
            delayed(emd_func)(pdds[i], pdds_[j])
            for i in range(n_rows) for j in range(n_cols)
        )
        return np.array(flat).reshape((n_rows, n_cols))

    n_rows, n_cols = len(pdds), len(pdds_)
    dm = np.empty((n_rows, n_cols))
    progress = tqdm.tqdm(total=n_rows * n_cols) if verbose else None
    for i in range(n_rows):
        for j in range(n_cols):
            dm[i, j] = EMD(pdds[i], pdds_[j], metric=metric, **kwargs)
            if progress is not None:
                progress.update(1)
    if progress is not None:
        progress.close()
    return dm
413
414
415
def PDD_pdist(
        pdds: List[np.ndarray],
        metric: str = 'chebyshev',
        n_jobs=None,
        verbose=0,
        backend='multiprocessing',
        **kwargs
) -> np.ndarray:
    """Compare a set of PDDs pairwise, returning a condensed distance
    matrix. Supports parallelisation via joblib. If using
    parallelisation, make sure to include a if __name__ == '__main__'
    guard around this function.

    Parameters
    ----------
    pdds : List[:class:`numpy.ndarray`]
        A list of PDDs. A single 2-D PDD is treated as a list of one,
        which gives an empty result.
    metric : str or callable, default 'chebyshev'
        Usually PDD rows are compared with the Chebyshev/l-infinity
        distance. Accepts any metric accepted by
        :func:`scipy.spatial.distance.cdist`.
    n_jobs : int, default None
        Maximum number of concurrent jobs for parallel processing with
        ``joblib``. Set to -1 to use the maximum. None, 0 or 1 run
        serially (the same rule as :func:`PDD_cdist`). Using parallel
        processing may be slower for small inputs.
    verbose : int, default 0
        Verbosity level. If using parallel processing, verbose is passed
        to :class:`joblib.Parallel` and larger values = more verbose.
        When serial, any truthy value shows a progress bar.
    backend : str, default 'multiprocessing'
        The parallelization backend implementation. For a list of
        supported backends, see the backend argument of
        :class:`joblib.Parallel`.
    **kwargs
        Forwarded to :func:`EMD`. ``return_transport`` is discarded.

    Returns
    -------
    cdm : :class:`numpy.ndarray`
        Returns a condensed distance matrix. Collapses a square distance
        matrix into a vector, just keeping the upper half. See the
        function :func:`squareform <scipy.spatial.distance.squareform>`
        from SciPy to convert to a symmetric square distance matrix.
    """

    # Consistent with PDD_cdist: a lone PDD (2-D array) is a list of one,
    # rather than being iterated row by row.
    if isinstance(pdds, np.ndarray) and len(pdds.shape) == 2:
        pdds = [pdds]

    # Only distances are stored, so a transport plan cannot be returned.
    kwargs.pop('return_transport', None)
    m = len(pdds)

    # Same rule as PDD_cdist, so joblib's negative n_jobs values
    # (e.g. -1 = all CPUs) also run in parallel.
    if n_jobs is not None and n_jobs not in (0, 1):
        # TODO: put results into preallocated empty array in place
        emd_func = partial(EMD, metric=metric, **kwargs)
        cdm = Parallel(backend=backend, n_jobs=n_jobs, verbose=verbose)(
            delayed(emd_func)(pdds[i], pdds[j])
            for i, j in combinations(range(m), 2)
        )
        return np.array(cdm)

    cdm_len = (m * (m - 1)) // 2
    cdm = np.empty(cdm_len, dtype=np.double)
    # tqdm's first positional argument is the iterable, so pass total=.
    pbar = tqdm.tqdm(total=cdm_len) if verbose else None
    try:
        # combinations() yields pairs in condensed (upper-triangle) order.
        for r, (i, j) in enumerate(combinations(range(m), 2)):
            cdm[r] = EMD(pdds[i], pdds[j], metric=metric, **kwargs)
            if pbar is not None:
                pbar.update(1)
    finally:
        if pbar is not None:
            pbar.close()
    return cdm
482
483
484
def emd(pdd: np.ndarray, pdd_: np.ndarray, **kwargs):
    """Lower-case alias for :func:`EMD() <.compare.EMD>`.

    Every argument is forwarded unchanged to :func:`EMD`. Its result is
    returned as is: a float, or ``(emd, transport_plan)`` when
    ``return_transport=True``.
    """
    result = EMD(pdd, pdd_, **kwargs)
    return result
487
488
489
def _unwrap_periodicset_list(psets_or_str, **reader_kwargs):
    """Normalise the inputs accepted by :func:`compare` into a flat list
    of ``PeriodicSet`` objects.

    Accepts a ``PeriodicSet``, a path, a refcode, or a list of any of
    these. Each list item is expanded with :func:`_extract_periodicsets`,
    and the results are concatenated in order.
    """

    if isinstance(psets_or_str, PeriodicSet):
        return [psets_or_str]

    if not isinstance(psets_or_str, list):
        return _extract_periodicsets(psets_or_str, **reader_kwargs)

    flattened = []
    for entry in psets_or_str:
        flattened.extend(_extract_periodicsets(entry, **reader_kwargs))
    return flattened
500
501
502
def _extract_periodicsets(item, **reader_kwargs):
    """Turn a single input into a list of ``PeriodicSets``.

    The input may be a ``PeriodicSet``, a string (path or refcode) or a
    file. A string that is neither an existing file nor an existing
    directory is treated as a CSD refcode and read with ``CSDReader``.
    Everything else is read with ``CifReader``.
    """

    if isinstance(item, PeriodicSet):
        return [item]

    looks_like_refcode = isinstance(item, str) and not (
        os.path.isfile(item) or os.path.isdir(item)
    )
    if looks_like_refcode:
        # CSDReader is not given the 'reader' option.
        reader_kwargs.pop('reader', None)
        return list(CSDReader(item, **reader_kwargs))

    # CifReader is not given the refcode-only options.
    for refcode_only in ('families', 'refcodes'):
        reader_kwargs.pop(refcode_only, None)
    return list(CifReader(item, **reader_kwargs))
515