Test Failed
Push — master ( baece5...0f0bce )
by Daniel
07:51
created

amd.compare.PDD_pdist()   B

Complexity

Conditions 7

Size

Total Lines 72
Code Lines 29

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 29
dl 0
loc 72
rs 7.784
c 0
b 0
f 0
cc 7
nop 6

How to fix   Long Method   

Long Method

Small methods make your code easier to understand, in particular if combined with a good name. Besides, if your method is small, finding a good name is usually much easier.

For example, if you find yourself adding comments to a method's body, this is usually a good sign to extract the commented part to a new method, and use the comment as a starting point when coming up with a good name for this new method.

Commonly applied refactorings include:

1
"""Functions for comparing AMDs and PDDs of crystals.
2
"""
3
4
from typing import List, Optional, Union, Tuple
5
from functools import partial
6
from itertools import combinations
7
from pathlib import Path
8
9
import numpy as np
10
import pandas as pd
11
from scipy.spatial.distance import cdist, pdist, squareform
12
from joblib import Parallel, delayed
13
import tqdm
14
15
from .io import CifReader, CSDReader
16
from .calculate import AMD, PDD
17
from ._emd import network_simplex
18
from .periodicset import PeriodicSet
19
from .utils import neighbours_from_distance_matrix
20
21
# Public API of this module.
__all__ = [
    'compare', 'EMD', 'AMD_cdist', 'AMD_pdist', 'PDD_cdist', 'PDD_pdist',
    'emd'
]
25
26
27
def compare(
        crystals,
        crystals_=None,
        by: str = 'AMD',
        k: int = 100,
        nearest: Optional[int] = None,
        reader: str = 'gemmi',
        remove_hydrogens: bool = False,
        disorder: str = 'skip',
        heaviest_component: bool = False,
        molecular_centres: bool = False,
        csd_refcodes: bool = False,
        refcode_families: bool = False,
        show_warnings: bool = True,
        collapse_tol: float = 1e-4,
        metric: str = 'chebyshev',
        n_jobs: Optional[int] = None,
        backend: str = 'multiprocessing',
        verbose: bool = False,
        low_memory: bool = False,
        **kwargs
) -> pd.DataFrame:
    r"""Given one or two sets of crystals, compare by AMD or PDD and
    return a pandas DataFrame of the distance matrix.

    Given one or two paths to CIFs, periodic sets, CSD refcodes or lists
    thereof, compare by AMD or PDD and return a pandas DataFrame of the
    distance matrix. Default is to compare by AMD with k = 100. Accepts
    most keyword arguments accepted by
    :class:`CifReader <.io.CifReader>`,
    :class:`CSDReader <.io.CSDReader>` and functions from
    :mod:`.compare`.

    Parameters
    ----------
    crystals : list of str or :class:`PeriodicSet <.periodicset.PeriodicSet>`
        A path, :class:`PeriodicSet <.periodicset.PeriodicSet>`, tuple
        or a list of those.
    crystals\_ : list of str or :class:`PeriodicSet <.periodicset.PeriodicSet>`, optional
        A path, :class:`PeriodicSet <.periodicset.PeriodicSet>`, tuple
        or a list of those. If omitted, ``crystals`` is compared
        pairwise with itself.
    by : str, default 'AMD'
        Use AMD or PDD to compare crystals (case-insensitive).
    k : int, default 100
        Parameter for AMD/PDD, the number of neighbour atoms to consider
        for each atom in a unit cell.
    nearest : int, default None
        Find a number of nearest neighbours instead of a full distance
        matrix between crystals. The returned DataFrame then has columns
        ``'ID 1', 'DIST 1', 'ID 2', 'DIST 2', ...`` indexed by the names
        of the first set of crystals.
    reader : str, optional
        The backend package used to parse the CIF. The default is
        :code:`gemmi`, :code:`pymatgen` and :code:`ase` are also
        accepted, as well as :code:`ccdc` if csd-python-api is
        installed. The ccdc reader should be able to read any format
        accepted by :class:`ccdc.io.EntryReader`, though only CIFs have
        been tested. Forced to :code:`ccdc` if ``heaviest_component`` or
        ``molecular_centres`` is True.
    remove_hydrogens : bool, optional
        Remove hydrogens from the crystals.
    disorder : str, optional
        Controls how disordered structures are handled. Default is
        ``skip`` which skips any crystal with disorder, since disorder
        conflicts with the periodic set model. To read disordered
        structures anyway, choose either :code:`ordered_sites` to remove
        atoms with disorder or :code:`all_sites` include all atoms
        regardless of disorder.
    heaviest_component : bool, optional, csd-python-api only
        Removes all but the heaviest molecule in
        the asymmetric unit, intended for removing solvents.
    molecular_centres : bool, default False, csd-python-api only
        Use the centres of molecules for comparison
        instead of centres of atoms.
    csd_refcodes : bool, optional, csd-python-api only
        Interpret ``crystals`` and ``crystals_`` as CSD refcodes or
        lists thereof, rather than paths.
    refcode_families : bool, optional, csd-python-api only
        Read all entries whose refcode starts with
        the given strings, or 'families' (e.g. giving 'DEBXIT' reads all
        entries with refcodes starting with DEBXIT). Implies
        ``csd_refcodes=True``.
    show_warnings : bool, optional
        Controls whether warnings that arise during reading are printed.
    collapse_tol : float, default 1e-4, ``by='PDD'`` only
        If two PDD rows have all elements closer
        than ``collapse_tol``, they are merged and weights are given to
        rows in proportion to the number of times they appeared.
    metric : str or callable, default 'chebyshev'
        The metric to compare AMDs/PDDs with. AMDs are compared directly
        with this metric. EMD is the metric used between PDDs, which
        requires giving a metric to use between PDD rows. Chebyshev
        (L-infinity) distance is the default. Accepts any metric
        accepted by :func:`scipy.spatial.distance.cdist`.
    n_jobs : int, default None, ``by='PDD'`` only
        Maximum number of concurrent jobs for
        parallel processing with :code:`joblib`. Set to -1 to use the
        maximum. Using parallel processing may be slower for small
        inputs.
    backend : str, default 'multiprocessing', ``by='PDD'`` only
        The parallelization backend implementation for PDD comparisons.
        For a list of supported backends, see the backend argument of
        :class:`joblib.Parallel`.
    verbose : bool, default False
        Prints a progress bar when reading crystals, calculating
        AMDs/PDDs and comparing PDDs. If using parallel processing
        (n_jobs > 1), the verbose argument of :class:`joblib.Parallel`
        is used, otherwise uses ``tqdm``.
    low_memory : bool, default False, ``by='AMD'`` only
        Use a slower but more memory efficient
        method for large collections of AMDs (metric 'chebyshev' only).
    **kwargs
        Forwarded to the comparison function (:func:`AMD_pdist`,
        :func:`AMD_cdist`, :func:`PDD_pdist` or :func:`PDD_cdist`).

    Returns
    -------
    df : :class:`pandas.DataFrame`
        DataFrame of the distance matrix for the given crystals compared
        by the chosen invariant.

    Raises
    ------
    ValueError
        If by is not 'AMD' or 'PDD', if either set given have no valid
        crystals to compare, or if crystals or crystals\_ are an invalid
        type.
    TypeError
        If ``csd_refcodes`` is True and an input is not a string or a
        list of strings.

    Examples
    --------
    Compare everything in a .cif (default, AMD with k=100)::

        df = amd.compare('data.cif')

    Compare everything in one cif with all crystals in all cifs in a
    directory (PDD, k=50)::

        df = amd.compare('data.cif', 'dir/to/cifs', by='PDD', k=50)

    **Examples (csd-python-api only)**

    Compare two crystals by CSD refcode (PDD, k=50)::

        df = amd.compare('DEBXIT01', 'DEBXIT02', csd_refcodes=True, by='PDD', k=50)

    Compare everything in a refcode family (AMD, k=100)::

        df = amd.compare('DEBXIT', csd_refcodes=True, refcode_families=True)
    """

    by = by.upper()
    if by not in ('AMD', 'PDD'):
        raise ValueError(
            "'by' parameter of amd.compare() must be 'AMD' or 'PDD' (passed "
            f"'{by}')"
        )

    # These options are only implemented by the csd-python-api reader.
    if heaviest_component or molecular_centres:
        reader = 'ccdc'

    if refcode_families:
        csd_refcodes = True

    reader_kwargs = {
        'reader': reader,
        'families': refcode_families,
        'remove_hydrogens': remove_hydrogens,
        'disorder': disorder,
        'heaviest_component': heaviest_component,
        'molecular_centres': molecular_centres,
        'show_warnings': show_warnings,
        'verbose': verbose,
    }

    # Read the first input; it must contain at least one crystal.
    crystals = _read_periodic_sets(crystals, csd_refcodes, reader_kwargs)
    if not crystals:
        raise ValueError(
            'First argument passed to amd.compare() contains no valid '
            'crystals/periodic sets'
        )
    names = [s.name for s in crystals]

    # Read the second input if given, otherwise compare the first pairwise.
    if crystals_ is None:
        names_ = names
    else:
        crystals_ = _read_periodic_sets(crystals_, csd_refcodes, reader_kwargs)
        if not crystals_:
            raise ValueError(
                'Second argument passed to amd.compare() contains no '
                'valid crystals/periodic sets'
            )
        names_ = [s.name for s in crystals_]

    # Choose the invariant and the pairwise/cross comparison functions,
    # passing each only the options it accepts.
    if by == 'AMD':
        def invariant(pset):
            return AMD(pset, k)
        pairwise_func, cross_func = AMD_pdist, AMD_cdist
        compare_kwargs = {'metric': metric, 'low_memory': low_memory, **kwargs}
    else:
        def invariant(pset):
            return PDD(pset, k, collapse_tol=collapse_tol)
        pairwise_func, cross_func = PDD_pdist, PDD_cdist
        compare_kwargs = {
            'metric': metric,
            'n_jobs': n_jobs,
            'backend': backend,
            'verbose': verbose,
            **kwargs
        }

    invs = _calculate_invariants(crystals, invariant, verbose)
    if crystals_ is None:
        dm = pairwise_func(invs, **compare_kwargs)
    else:
        invs_ = _calculate_invariants(crystals_, invariant, verbose)
        dm = cross_func(invs, invs_, **compare_kwargs)

    return _distance_matrix_to_dataframe(dm, names, names_, nearest)


def _read_periodic_sets(crystals, csd_refcodes, reader_kwargs):
    """Read one input of :func:`compare` into a list of PeriodicSets,
    interpreting it as CSD refcodes if ``csd_refcodes`` is True and as
    paths/PeriodicSets/tuples otherwise."""
    if csd_refcodes:
        return _unwrap_refcode_list(crystals, **reader_kwargs)
    return _unwrap_pset_list(crystals, **reader_kwargs)


def _calculate_invariants(psets, invariant, verbose):
    """Apply ``invariant`` to each PeriodicSet in ``psets``, showing a
    tqdm progress bar if ``verbose`` and always closing it afterwards."""
    if not verbose:
        return [invariant(s) for s in psets]
    with tqdm.tqdm(psets, desc='Calculating', delay=1) as progress:
        return [invariant(s) for s in progress]


def _distance_matrix_to_dataframe(dm, names, names_, nearest):
    """Wrap a (condensed or square) distance matrix in a DataFrame, or,
    if ``nearest`` is truthy, a table of the ``nearest`` closest items
    in ``names_`` to each item in ``names``."""
    if nearest:
        nn_dm, inds = neighbours_from_distance_matrix(nearest, dm)
        data = {}
        for i in range(nearest):
            data[f'ID {i + 1}'] = [names_[j] for j in inds[:, i]]
            data[f'DIST {i + 1}'] = nn_dm[:, i]
        return pd.DataFrame(data, index=names)

    # Pairwise comparisons return a condensed (1D) distance matrix.
    if dm.ndim == 1:
        dm = squareform(dm)
    return pd.DataFrame(dm, index=names, columns=names_)
277
278
279
def EMD(
        pdd: np.ndarray,
        pdd_: np.ndarray,
        metric: Optional[str] = 'chebyshev',
        return_transport: Optional[bool] = False,
        **kwargs
) -> Union[float, Tuple[float, np.ndarray[np.float64]]]:
    r"""Earth mover's distance (Wasserstein metric) between two PDDs.

    Column 0 of each PDD is passed to ``network_simplex`` as the
    weights; the remaining columns of every row pair are compared with
    ``metric`` to build the cost matrix.

    Parameters
    ----------
    pdd : :class:`numpy.ndarray`
        PDD of a crystal.
    pdd\_ : :class:`numpy.ndarray`
        PDD of a crystal.
    metric : str or callable, default 'chebyshev'
        Distance used between PDD rows (EMD needs a ground distance).
        Chebyshev (L-infinity) is the default, as for AMDs. Accepts any
        metric accepted by :func:`scipy.spatial.distance.cdist`.
    return_transport : bool, default False
        If True, return the tuple ``(emd, transport_plan)``, where
        transport_plan describes the optimal flow.
    **kwargs
        Extra keyword arguments for :func:`scipy.spatial.distance.cdist`.

    Returns
    -------
    emd : float
        Earth mover's distance between the two PDDs, or the tuple
        ``(emd, transport_plan)`` if ``return_transport`` is True.

    Raises
    ------
    ValueError
        Raised by ``cdist`` if ``pdd`` and ``pdd_`` do not have the same
        number of columns.
    """

    weights, weights_ = pdd[:, 0], pdd_[:, 0]
    cost_matrix = cdist(pdd[:, 1:], pdd_[:, 1:], metric=metric, **kwargs)
    distance, flow = network_simplex(weights, weights_, cost_matrix)

    if not return_transport:
        return distance
    return distance, flow
323
324
325
def AMD_cdist(
        amds,
        amds_,
        metric: str = 'chebyshev',
        low_memory: bool = False,
        **kwargs
) -> np.ndarray:
    r"""Compare two sets of AMDs with each other, returning a distance
    matrix. This function is essentially
    :func:`scipy.spatial.distance.cdist` with the default metric
    ``chebyshev`` and a low memory option.

    Parameters
    ----------
    amds : ArrayLike
        A list/array of AMDs, or a single AMD (1D), which is treated as a
        collection of one.
    amds\_ : ArrayLike
        A list/array of AMDs, or a single AMD (1D).
    metric : str or callable, default 'chebyshev'
        Usually AMDs are compared with the Chebyshev (L-infinity) distance.
        Accepts any metric accepted by :func:`scipy.spatial.distance.cdist`.
    low_memory : bool, default False
        Use a slower but more memory efficient method for large collections of
        AMDs (metric 'chebyshev' only).
    **kwargs
        Extra keyword arguments for :func:`scipy.spatial.distance.cdist`
        (ignored if ``low_memory`` is True).

    Returns
    -------
    dm : :class:`numpy.ndarray`
        A distance matrix shape ``(len(amds), len(amds_))``. ``dm[i, j]`` is
        the distance (given by ``metric``) between ``amds[i]`` and
        ``amds_[j]``.

    Raises
    ------
    ValueError
        If ``low_memory`` is True and ``metric`` is not ``'chebyshev'``.
    """

    amds, amds_ = np.asarray(amds), np.asarray(amds_)

    # Promote a single AMD to a collection of one so the result is 2D.
    if amds.ndim == 1:
        amds = np.array([amds])
    if amds_.ndim == 1:
        amds_ = np.array([amds_])

    if not low_memory:
        return cdist(amds, amds_, metric=metric, **kwargs)

    if metric != 'chebyshev':
        raise ValueError(
            "'low_memory' parameter of amd.AMD_cdist() only implemented "
            "with metric='chebyshev'."
        )

    # Fill the matrix one row at a time, so each iteration only allocates
    # a temporary the size of amds_.
    dm = np.empty((len(amds), len(amds_)))
    for i, amd_vec in enumerate(amds):
        dm[i] = np.amax(np.abs(amds_ - amd_vec), axis=-1)
    return dm
377
378
379
def AMD_pdist(
        amds,
        metric: str = 'chebyshev',
        low_memory: bool = False,
        **kwargs
) -> np.ndarray:
    """Compare a set of AMDs pairwise, returning a condensed distance
    matrix. This function is essentially
    :func:`scipy.spatial.distance.pdist` with the default metric
    ``chebyshev`` and a low memory parameter.

    Parameters
    ----------
    amds : ArrayLike
        A list/array of AMDs. A single AMD (1D) is treated as a
        collection of one.
    metric : str or callable, default 'chebyshev'
        Usually AMDs are compared with the Chebyshev (L-infinity)
        distance. Accepts any metric accepted by
        :func:`scipy.spatial.distance.pdist`.
    low_memory : bool, default False
        Use a slower but more memory efficient method for large
        collections of AMDs (metric 'chebyshev' only).
    **kwargs
        Extra keyword arguments for :func:`scipy.spatial.distance.pdist`
        (ignored if ``low_memory`` is True).

    Returns
    -------
    cdm : :class:`numpy.ndarray`
        Returns a condensed distance matrix. Collapses a square distance
        matrix into a vector, just keeping the upper half. See the
        function :func:`squareform <scipy.spatial.distance.squareform>`
        from SciPy to convert to a symmetric square distance matrix.

    Raises
    ------
    ValueError
        If ``low_memory`` is True and ``metric`` is not ``'chebyshev'``.
    """

    amds = np.asarray(amds)
    if amds.ndim == 1:
        amds = np.array([amds])

    if not low_memory:
        return pdist(amds, metric=metric, **kwargs)

    n_amds = len(amds)
    if metric != 'chebyshev':
        raise ValueError(
            "'low_memory' parameter of amd.AMD_pdist() only implemented "
            "with metric='chebyshev'."
        )

    # Row i of the upper triangle (distances from amds[i] to every later
    # AMD) occupies a contiguous run of the condensed vector.
    cdm = np.empty((n_amds * (n_amds - 1)) // 2, dtype=np.float64)
    start = 0
    for i, row in enumerate(amds):
        later = amds[i + 1:]
        stop = start + n_amds - i - 1
        cdm[start:stop] = np.amax(np.abs(later - row), axis=-1)
        start = stop
    return cdm
433
434
435
def PDD_cdist(
        pdds: List[np.ndarray],
        pdds_: List[np.ndarray],
        metric: str = 'chebyshev',
        backend: str = 'multiprocessing',
        n_jobs: Optional[int] = None,
        verbose: bool = False,
        **kwargs
) -> np.ndarray:
    r"""Compare two sets of PDDs with each other, returning a distance
    matrix. Supports parallel processing via joblib. If using
    parallelisation, make sure to include an if __name__ == '__main__'
    guard around this function.

    Parameters
    ----------
    pdds : List[:class:`numpy.ndarray`]
        A list of PDDs.
    pdds\_ : List[:class:`numpy.ndarray`]
        A list of PDDs.
    metric : str or callable, default 'chebyshev'
        Usually PDD rows are compared with the Chebyshev/l-infinity
        distance. Accepts any metric accepted by
        :func:`scipy.spatial.distance.cdist`.
    backend : str, default 'multiprocessing'
        The parallelization backend implementation. For a list of
        supported backends, see the backend argument of
        :class:`joblib.Parallel`.
    n_jobs : int, default None
        Maximum number of concurrent jobs for parallel processing with
        ``joblib``. Set to -1 to use the maximum. ``None``, 0 and 1 run
        serially. Using parallel processing may be slower for small
        inputs.
    verbose : bool, default False
        Prints a progress bar. If using parallel processing
        (n_jobs > 1), the verbose argument of :class:`joblib.Parallel`
        is used, otherwise uses tqdm.
    **kwargs
        Passed to :func:`EMD` (``return_transport`` is ignored).

    Returns
    -------
    dm : :class:`numpy.ndarray`
        Returns a distance matrix shape ``(len(pdds), len(pdds_))``. The
        :math:`ij` th entry is the distance between ``pdds[i]`` and
        ``pdds_[j]`` given by Earth mover's distance.
    """

    # Only scalar distances can be stored in the matrix.
    kwargs.pop('return_transport', None)
    n, m = len(pdds), len(pdds_)
    emd_func = partial(EMD, metric=metric, **kwargs)

    if n_jobs is not None and n_jobs not in (0, 1):
        # TODO: put results into preallocated empty array in place
        results = Parallel(
            backend=backend, n_jobs=n_jobs, verbose=3 if verbose else 0
        )(
            delayed(emd_func)(pdd, pdd_) for pdd in pdds for pdd_ in pdds_
        )
        return np.array(results).reshape((n, m))

    dm = np.empty((n, m))
    # k is read from the first PDD only for the progress bar label, so an
    # empty ``pdds`` no longer raises IndexError.
    show_progress = verbose and n > 0
    desc = None
    if show_progress:
        k = pdds[0].shape[-1] - 1
        desc = f'Comparing {n}x{m} PDDs (k={k})'

    with tqdm.tqdm(desc=desc, total=n * m, disable=not show_progress) as progress_bar:
        for i, pdd in enumerate(pdds):
            for j, pdd_ in enumerate(pdds_):
                dm[i, j] = emd_func(pdd, pdd_)
                progress_bar.update(1)

    return dm
509
510
511
def PDD_pdist(
        pdds: List[np.ndarray],
        metric: str = 'chebyshev',
        backend: str = 'multiprocessing',
        n_jobs: Optional[int] = None,
        verbose: bool = False,
        **kwargs
) -> np.ndarray:
    """Compare a set of PDDs pairwise, returning a condensed distance
    matrix. Supports parallelisation via joblib. If using
    parallelisation, make sure to include an if __name__ == '__main__'
    guard around this function.

    Parameters
    ----------
    pdds : List[:class:`numpy.ndarray`]
        A list of PDDs.
    metric : str or callable, default 'chebyshev'
        Usually PDD rows are compared with the Chebyshev/l-infinity
        distance. Accepts any metric accepted by
        :func:`scipy.spatial.distance.cdist`.
    backend : str, default 'multiprocessing'
        The parallelization backend implementation. For a list of
        supported backends, see the backend argument of
        :class:`joblib.Parallel`.
    n_jobs : int, default None
        Maximum number of concurrent jobs for parallel processing with
        ``joblib``. Set to -1 to use the maximum. ``None``, 0 and 1 run
        serially. Using parallel processing may be slower for small
        inputs.
    verbose : bool, default False
        Prints a progress bar. If using parallel processing
        (n_jobs > 1), the verbose argument of :class:`joblib.Parallel`
        is used, otherwise uses tqdm.
    **kwargs
        Passed to :func:`EMD` (``return_transport`` is ignored).

    Returns
    -------
    cdm : :class:`numpy.ndarray`
        Returns a condensed distance matrix. Collapses a square distance
        matrix into a vector, just keeping the upper half. See the
        function :func:`squareform <scipy.spatial.distance.squareform>`
        from SciPy to convert to a symmetric square distance matrix.
    """

    # Only scalar distances can be stored in the condensed matrix.
    kwargs.pop('return_transport', None)
    m = len(pdds)
    emd_func = partial(EMD, metric=metric, **kwargs)
    # (0, 1), (0, 2), ..., (1, 2), ...: the order of a condensed matrix.
    pairs = combinations(range(m), 2)

    # Same rule as PDD_cdist: negative joblib values such as -1 ("use all
    # CPUs") must run in parallel too, not fall back to the serial path.
    if n_jobs is not None and n_jobs not in (0, 1):
        # TODO: put results into preallocated empty array in place
        results = Parallel(
            backend=backend, n_jobs=n_jobs, verbose=3 if verbose else 0
        )(
            delayed(emd_func)(pdds[i], pdds[j]) for i, j in pairs
        )
        return np.array(results)

    cdm_len = (m * (m - 1)) // 2
    cdm = np.empty(cdm_len, dtype=np.float64)
    # k is read from the first PDD only for the progress bar label, so an
    # empty ``pdds`` no longer raises IndexError.
    show_progress = verbose and m > 0
    desc = None
    if show_progress:
        k = pdds[0].shape[-1] - 1
        desc = f'Comparing {len(pdds)} PDDs pairwise (k={k})'

    with tqdm.tqdm(desc=desc, total=cdm_len, disable=not show_progress) as progress_bar:
        for r, (i, j) in enumerate(pairs):
            cdm[r] = emd_func(pdds[i], pdds[j])
            progress_bar.update(1)

    return cdm
583
584
585
def emd(
        pdd: np.ndarray, pdd_: np.ndarray, **kwargs
) -> Union[float, Tuple[float, np.ndarray[np.float64]]]:
    """Lower-case alias of :func:`EMD() <.compare.EMD>`.

    Both PDDs and every keyword argument are forwarded to ``EMD``
    unchanged, and its result is returned as-is.
    """
    result = EMD(pdd, pdd_, **kwargs)
    return result
590
591
592
def _unwrap_refcode_list(refcodes, **reader_kwargs):
593
    """Given string or list of strings, interpret as CSD refcodes and
594
    return a list of PeriodicSets.
595
    """
596
597
    reader_kwargs.pop('reader', None)
598
    if isinstance(refcodes, list):
599
        if not all(isinstance(refcode, str) for refcode in refcodes):
600
            raise TypeError(
601
                f'amd.compare(refcodes=True) expects a string or list of '
602
                'strings.'
603
            )
604
    elif not isinstance(refcodes, str):
605
        raise TypeError(
606
            f'amd.compare(refcodes=True) expects a string or list of '
607
            f'strings, got {refcodes.__class__.__name__}'
608
        )
609
    return list(CSDReader(refcodes, **reader_kwargs))
610
611
612
def _unwrap_pset_list(psets, **reader_kwargs):
    """Given a valid input for amd.compare(), return a list of
    PeriodicSets. Accepts paths, PeriodicSets, tuples or lists
    thereof.

    ``reader_kwargs`` are forwarded to ``CifReader`` for path inputs,
    except ``'families'``, which is dropped.

    Raises
    ------
    ValueError
        If an item is not a path, PeriodicSet or tuple.
    """

    def _extract_periodicsets(item, **reader_kwargs):
        """Given a path, PeriodicSet or tuple, return a list of the
        PeriodicSet(s) it represents."""

        if isinstance(item, PeriodicSet):
            return [item]
        if isinstance(item, tuple):
            # Tuples are passed positionally to PeriodicSet; presumably
            # (motif, cell) -- confirm against PeriodicSet's signature.
            return [PeriodicSet(item[0], item[1])]
        try:
            path = Path(item)
        except TypeError as err:
            raise ValueError(
                'amd.compare() expects a string, amd.PeriodicSet or tuple, '
                f'got {item.__class__.__name__}'
            ) from err
        return list(CifReader(path, **reader_kwargs))

    # 'families' only applies to CSD refcodes, not to CifReader inputs.
    reader_kwargs.pop('families', None)
    if isinstance(psets, list):
        return [s for item in psets
                for s in _extract_periodicsets(item, **reader_kwargs)]
    return _extract_periodicsets(psets, **reader_kwargs)
639