Passed
Push — master ( a617b6...34fcd6 )
by Daniel
07:54
created

amd.compare.AMD_cdist()   B

Complexity

Conditions 6

Size

Total Lines 52
Code Lines 20

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 20
dl 0
loc 52
rs 8.4666
c 0
b 0
f 0
cc 6
nop 5

How to fix   Long Method   

Long Method

Small methods make your code easier to understand, in particular if combined with a good name. Besides, if your method is small, finding a good name is usually much easier.

For example, if you find yourself adding comments to a method's body, this is usually a good sign to extract the commented part to a new method, and use the comment as a starting point when coming up with a good name for this new method.

Commonly applied refactorings include:

1
"""Functions for comparing AMDs and PDDs of crystals.
2
"""
3
4
from typing import List, Optional, Union, Tuple
5
from functools import partial
6
from itertools import combinations
7
from pathlib import Path
8
9
import numpy as np
10
import pandas as pd
11
from scipy.spatial.distance import cdist, pdist, squareform
12
from joblib import Parallel, delayed
13
import tqdm
14
15
from .io import CifReader, CSDReader
16
from .calculate import AMD, PDD
17
from ._emd import network_simplex
18
from .periodicset import PeriodicSet
19
from .utils import neighbours_from_distance_matrix
20
21
# Public names of this module, i.e. what a star-import of it exposes.
__all__ = [
    'compare',
    'EMD',
    'AMD_cdist',
    'AMD_pdist',
    'PDD_cdist',
    'PDD_pdist',
    'emd'
]
30
31
32
def compare(
        crystals,
        crystals_=None,
        by: str = 'AMD',
        k: int = 100,
        nearest: Optional[int] = None,
        reader: str = 'gemmi',
        remove_hydrogens: bool = False,
        disorder: str = 'skip',
        heaviest_component: bool = False,
        molecular_centres: bool = False,
        csd_refcodes: bool = False,
        refcode_families: bool = False,
        show_warnings: bool = True,
        collapse_tol: float = 1e-4,
        metric: str = 'chebyshev',
        n_jobs: Optional[int] = None,
        backend: str = 'multiprocessing',
        verbose: bool = False,
        low_memory: bool = False,
        **kwargs
) -> pd.DataFrame:
    r"""Given one or two sets of crystals, compare by AMD or PDD and
    return a pandas DataFrame of the distance matrix.

    Given one or two paths to CIFs, periodic sets, CSD refcodes or lists
    thereof, compare by AMD or PDD and return a pandas DataFrame of the
    distance matrix. Default is to compare by AMD with k = 100. Accepts
    most keyword arguments accepted by
    :class:`CifReader <.io.CifReader>`,
    :class:`CSDReader <.io.CSDReader>` and functions from
    :mod:`.compare`.

    Parameters
    ----------
    crystals : list of str or :class:`PeriodicSet <.periodicset.PeriodicSet>`
        A path, :class:`PeriodicSet <.periodicset.PeriodicSet>`, tuple
        or a list of those.
    crystals\_ : list of str or :class:`PeriodicSet <.periodicset.PeriodicSet>`, optional
        A path, :class:`PeriodicSet <.periodicset.PeriodicSet>`, tuple
        or a list of those. If omitted, ``crystals`` is compared
        pairwise with itself.
    by : str, default 'AMD'
        Use AMD or PDD to compare crystals (case-insensitive).
    k : int, default 100
        Parameter for AMD/PDD, the number of neighbour atoms to consider
        for each atom in a unit cell.
    nearest : int, default None
        Find a number of nearest neighbours instead of a full distance
        matrix between crystals.
    reader : str, optional
        The backend package used to parse the CIF. The default is
        :code:`gemmi`, :code:`pymatgen` and :code:`ase` are also
        accepted, as well as :code:`ccdc` if csd-python-api is
        installed. The ccdc reader should be able to read any format
        accepted by :class:`ccdc.io.EntryReader`, though only CIFs have
        been tested. Forced to :code:`ccdc` when ``heaviest_component``
        or ``molecular_centres`` is set.
    remove_hydrogens : bool, optional
        Remove hydrogens from the crystals.
    disorder : str, optional
        Controls how disordered structures are handled. Default is
        ``skip`` which skips any crystal with disorder, since disorder
        conflicts with the periodic set model. To read disordered
        structures anyway, choose either :code:`ordered_sites` to remove
        atoms with disorder or :code:`all_sites` include all atoms
        regardless of disorder.
    heaviest_component : bool, optional, csd-python-api only
        Removes all but the heaviest molecule in
        the asymmetric unit, intended for removing solvents.
    molecular_centres : bool, default False, csd-python-api only
        Use the centres of molecules for comparison
        instead of centres of atoms.
    csd_refcodes : bool, optional, csd-python-api only
        Interpret ``crystals`` and ``crystals_`` as CSD refcodes or
        lists thereof, rather than paths.
    refcode_families : bool, optional, csd-python-api only
        Read all entries whose refcode starts with
        the given strings, or 'families' (e.g. giving 'DEBXIT' reads all
        entries with refcodes starting with DEBXIT). Implies
        ``csd_refcodes=True``.
    show_warnings : bool, optional
        Controls whether warnings that arise during reading are printed.
    collapse_tol: float, default 1e-4, ``by='PDD'`` only
        If two PDD rows have all elements closer
        than ``collapse_tol``, they are merged and weights are given to
        rows in proportion to the number of times they appeared.
    metric : str or callable, default 'chebyshev'
        The metric to compare AMDs/PDDs with. AMDs are compared directly
        with this metric. EMD is the metric used between PDDs, which
        requires giving a metric to use between PDD rows. Chebyshev
        (L-infinity) distance is the default. Accepts any metric
        accepted by :func:`scipy.spatial.distance.cdist`.
    n_jobs : int, default None, ``by='PDD'`` only
        Maximum number of concurrent jobs for
        parallel processing with :code:`joblib`. Set to -1 to use the
        maximum. Using parallel processing may be slower for small
        inputs.
    backend : str, default 'multiprocessing', ``by='PDD'`` only
        The parallelization backend implementation for PDD comparisons.
        For a list of supported backends, see the backend argument of
        :class:`joblib.Parallel`.
    verbose : bool, default False
        Prints a progress bar when reading crystals, calculating
        AMDs/PDDs and comparing PDDs. If using parallel processing
        (n_jobs > 1), the verbose argument of :class:`joblib.Parallel`
        is used, otherwise uses ``tqdm``.
    low_memory : bool, default False, ``by='AMD'`` only
        Use a slower but more memory efficient
        method for large collections of AMDs (metric 'chebyshev' only).
    **kwargs
        Extra keyword arguments forwarded to the AMD/PDD comparison
        function that is called.

    Returns
    -------
    df : :class:`pandas.DataFrame`
        DataFrame of the distance matrix for the given crystals compared
        by the chosen invariant. If ``nearest`` is given, the columns
        are instead ``ID 1, DIST 1, ID 2, DIST 2, ...`` holding the
        names of and distances to the nearest neighbours of each row.

    Raises
    ------
    ValueError
        If by is not 'AMD' or 'PDD', if either set given have no valid
        crystals to compare, or if crystals or crystals\_ are an invalid
        type.
    TypeError
        If ``csd_refcodes`` is in effect and crystals or crystals\_ are
        not a string or list of strings.

    Examples
    --------
    Compare everything in a .cif (default, AMD with k=100)::

        df = amd.compare('data.cif')

    Compare everything in one cif with all crystals in all cifs in a
    directory (PDD, k=50)::

        df = amd.compare('data.cif', 'dir/to/cifs', by='PDD', k=50)

    **Examples (csd-python-api only)**

    Compare two crystals by CSD refcode (PDD, k=50)::

        df = amd.compare('DEBXIT01', 'DEBXIT02', csd_refcodes=True, by='PDD', k=50)

    Compare everything in a refcode family (AMD, k=100)::

        df = amd.compare('DEBXIT', refcode_families=True)
    """

    by = by.upper()
    if by not in ('AMD', 'PDD'):
        raise ValueError(
            "'by' parameter of amd.compare() must be 'AMD' or 'PDD' (passed "
            f"'{by}')"
        )

    # These options are documented as csd-python-api only, so they force
    # the ccdc reader regardless of what was requested.
    if heaviest_component or molecular_centres:
        reader = 'ccdc'

    # Reading refcode families implies the inputs are refcodes.
    if refcode_families:
        csd_refcodes = True

    reader_kwargs = {
        'reader': reader,
        'families': refcode_families,
        'remove_hydrogens': remove_hydrogens,
        'disorder': disorder,
        'heaviest_component': heaviest_component,
        'molecular_centres': molecular_centres,
        'show_warnings': show_warnings,
        'verbose': verbose,
    }

    def _read(source, position):
        # Turn one user input into a non-empty list of periodic sets.
        if csd_refcodes:
            psets = _unwrap_refcode_list(source, **reader_kwargs)
        else:
            psets = _unwrap_pset_list(source, **reader_kwargs)
        if not psets:
            raise ValueError(
                f'{position} argument passed to amd.compare() contains no '
                'valid crystals/periodic sets'
            )
        return psets

    def _invariants(psets):
        # The context manager guarantees the progress bar is closed, even
        # if an invariant calculation raises; the bar is inert when
        # verbose is False.
        with tqdm.tqdm(
            psets, desc='Calculating', delay=1, disable=not verbose
        ) as progress_bar:
            if by == 'AMD':
                return [AMD(s, k) for s in progress_bar]
            return [
                PDD(s, k, collapse_tol=collapse_tol) for s in progress_bar
            ]

    # Read and validate both inputs before any invariant is calculated.
    crystals = _read(crystals, 'First')
    names = [s.name for s in crystals]
    if crystals_ is None:
        names_ = names
    else:
        crystals_ = _read(crystals_, 'Second')
        names_ = [s.name for s in crystals_]

    invs = _invariants(crystals)
    invs_ = None if crystals_ is None else _invariants(crystals_)

    if by == 'AMD':
        # AMD comparisons take no parallelisation/progress arguments.
        compare_kwargs = {
            'metric': metric,
            'low_memory': low_memory,
            **kwargs
        }
        if invs_ is None:
            dm = AMD_pdist(invs, **compare_kwargs)
        else:
            dm = AMD_cdist(invs, invs_, **compare_kwargs)
    else:
        # PDD comparisons have no low_memory option.
        compare_kwargs = {
            'metric': metric,
            'n_jobs': n_jobs,
            'backend': backend,
            'verbose': verbose,
            **kwargs
        }
        if invs_ is None:
            dm = PDD_pdist(invs, **compare_kwargs)
        else:
            dm = PDD_cdist(invs, invs_, **compare_kwargs)

    if nearest:
        nn_dm, inds = neighbours_from_distance_matrix(nearest, dm)
        data = {}
        for i in range(nearest):
            data[f'ID {i + 1}'] = [names_[j] for j in inds[:, i]]
            data[f'DIST {i + 1}'] = nn_dm[:, i]
        return pd.DataFrame(data, index=names)

    # Pairwise comparisons return a condensed (1D) distance matrix.
    if dm.ndim == 1:
        dm = squareform(dm)
    return pd.DataFrame(dm, index=names, columns=names_)
282
283
284
def EMD(
        pdd: np.ndarray,
        pdd_: np.ndarray,
        metric: Optional[str] = 'chebyshev',
        return_transport: Optional[bool] = False,
        **kwargs
) -> Union[float, Tuple[float, np.ndarray]]:
    r"""Earth mover's distance (EMD, also known as the Wasserstein
    metric) between two PDDs.

    The first column of each PDD is handed to the network simplex solver
    together with the matrix of distances between the remaining columns
    of every pair of rows.

    Parameters
    ----------
    pdd : :class:`numpy.ndarray`
        PDD of a crystal.
    pdd\_ : :class:`numpy.ndarray`
        PDD of a crystal.
    metric : str or callable, default 'chebyshev'
        Distance used between PDD rows; Chebyshev (L-infinity) by
        default, as with AMDs. Any metric accepted by
        :func:`scipy.spatial.distance.cdist` may be given.
    return_transport: bool, default False
        If True, return the tuple ``(emd, transport_plan)`` rather than
        just the distance.
    **kwargs
        Forwarded to :func:`scipy.spatial.distance.cdist`.

    Returns
    -------
    emd : float
        Earth mover's distance between the two PDDs, or the tuple
        ``(emd, transport_plan)`` if ``return_transport`` is True.

    Raises
    ------
    ValueError
        Thrown if ``pdd`` and ``pdd_`` do not have the same number of
        columns.
    """

    # Column 0 goes to the solver as-is; the rest of each row is compared.
    first_col, rows = pdd[:, 0], pdd[:, 1:]
    first_col_, rows_ = pdd_[:, 0], pdd_[:, 1:]

    row_distances = cdist(rows, rows_, metric=metric, **kwargs)
    distance, flow = network_simplex(first_col, first_col_, row_distances)

    return (distance, flow) if return_transport else distance
328
329
330
def AMD_cdist(
        amds,
        amds_,
        metric: str = 'chebyshev',
        low_memory: bool = False,
        **kwargs
) -> np.ndarray:
    r"""Compare two sets of AMDs with each other, returning a distance
    matrix. This function is essentially
    :func:`scipy.spatial.distance.cdist` with the default metric
    ``chebyshev`` and a low memory option.

    Parameters
    ----------
    amds : ArrayLike
        A list/array of AMDs. A single AMD (1D) is treated as a
        collection of one.
    amds\_ : ArrayLike
        A list/array of AMDs. A single AMD (1D) is treated as a
        collection of one.
    metric : str or callable, default 'chebyshev'
        Usually AMDs are compared with the Chebyshev (L-infinity) distance.
        Accepts any metric accepted by :func:`scipy.spatial.distance.cdist`.
    low_memory : bool, default False
        Use a slower but more memory efficient method for large collections of
        AMDs (metric 'chebyshev' only). ``**kwargs`` are not used in this
        mode.
    **kwargs
        Forwarded to :func:`scipy.spatial.distance.cdist` when
        ``low_memory`` is False.

    Returns
    -------
    dm : :class:`numpy.ndarray`
        A distance matrix shape ``(len(amds), len(amds_))``. ``dm[i, j]`` is
        the distance (given by ``metric``) between ``amds[i]`` and
        ``amds_[j]``.

    Raises
    ------
    ValueError
        If ``low_memory`` is True and ``metric`` is not ``'chebyshev'``,
        or if the AMDs in ``amds`` and ``amds_`` differ in length.
    """

    amds, amds_ = np.asarray(amds), np.asarray(amds_)

    # Promote a single AMD vector to a one-row collection.
    if amds.ndim == 1:
        amds = amds[np.newaxis, :]
    if amds_.ndim == 1:
        amds_ = amds_[np.newaxis, :]

    if not low_memory:
        return cdist(amds, amds_, metric=metric, **kwargs)

    if metric != 'chebyshev':
        raise ValueError(
            "'low_memory' parameter of amd.AMD_cdist() only implemented "
            "with metric='chebyshev'."
        )
    # Without this check a one-column input would silently broadcast
    # against the other collection, whereas cdist raises in that case.
    if amds.shape[-1] != amds_.shape[-1]:
        raise ValueError(
            'amd.AMD_cdist() requires AMDs of the same length, got lengths '
            f'{amds.shape[-1]} and {amds_.shape[-1]}.'
        )

    # One row of the result at a time, so no (n, m, k) array is built.
    dm = np.empty((len(amds), len(amds_)))
    for i, amd_vec in enumerate(amds):
        dm[i] = np.amax(np.abs(amds_ - amd_vec), axis=-1)
    return dm
382
383
384
def AMD_pdist(
        amds,
        metric: str = 'chebyshev',
        low_memory: bool = False,
        **kwargs
) -> np.ndarray:
    """Compare a set of AMDs pairwise, returning a condensed distance
    matrix. This function is essentially
    :func:`scipy.spatial.distance.pdist` with the default metric
    ``chebyshev`` and a low memory parameter.

    Parameters
    ----------
    amds : ArrayLike
        A list/array of AMDs. A single AMD (1D) is treated as a
        collection of one.
    metric : str or callable, default 'chebyshev'
        Usually AMDs are compared with the Chebyshev (L-infinity)
        distance. Accepts any metric accepted by
        :func:`scipy.spatial.distance.pdist`.
    low_memory : bool, default False
        Use a slower but more memory efficient method for large
        collections of AMDs (metric 'chebyshev' only). ``**kwargs`` are
        not used in this mode.
    **kwargs
        Forwarded to :func:`scipy.spatial.distance.pdist` when
        ``low_memory`` is False.

    Returns
    -------
    cdm : :class:`numpy.ndarray`
        Returns a condensed distance matrix. Collapses a square distance
        matrix into a vector, just keeping the upper half. See the
        function :func:`squareform <scipy.spatial.distance.squareform>`
        from SciPy to convert to a symmetric square distance matrix.

    Raises
    ------
    ValueError
        If ``low_memory`` is True and ``metric`` is not ``'chebyshev'``.
    """

    amds = np.asarray(amds)

    # Promote a single AMD vector to a one-row collection.
    if amds.ndim == 1:
        amds = amds[np.newaxis, :]

    if not low_memory:
        return pdist(amds, metric=metric, **kwargs)

    if metric != 'chebyshev':
        raise ValueError(
            "'low_memory' parameter of amd.AMD_pdist() only implemented "
            "with metric='chebyshev'."
        )

    m = len(amds)
    cdm = np.empty((m * (m - 1)) // 2, dtype=np.float64)
    start = 0
    # Row i of the upper triangle holds distances from amds[i] to every
    # later AMD. The last AMD has no later partners, so stop at m - 1.
    for i in range(m - 1):
        stop = start + m - i - 1
        cdm[start:stop] = np.amax(np.abs(amds[i + 1:] - amds[i]), axis=-1)
        start = stop
    return cdm
438
439
440
def PDD_cdist(
        pdds: List[np.ndarray],
        pdds_: List[np.ndarray],
        metric: str = 'chebyshev',
        backend: str = 'multiprocessing',
        n_jobs: Optional[int] = None,
        verbose: bool = False,
        **kwargs
) -> np.ndarray:
    r"""Compare two sets of PDDs with each other, returning a distance
    matrix. Supports parallel processing via joblib. If using
    parallelisation, make sure to include an if __name__ == '__main__'
    guard around this function.

    Parameters
    ----------
    pdds : List[:class:`numpy.ndarray`]
        A list of PDDs.
    pdds\_ : List[:class:`numpy.ndarray`]
        A list of PDDs.
    metric : str or callable, default 'chebyshev'
        Usually PDD rows are compared with the Chebyshev/l-infinity
        distance. Accepts any metric accepted by
        :func:`scipy.spatial.distance.cdist`.
    backend : str, default 'multiprocessing'
        The parallelization backend implementation. For a list of
        supported backends, see the backend argument of
        :class:`joblib.Parallel`.
    n_jobs : int, default None
        Maximum number of concurrent jobs for parallel processing with
        ``joblib``. Set to -1 to use the maximum. ``None``, 0 and 1 run
        serially. Using parallel processing may be slower for small
        inputs.
    verbose : bool, default False
        Prints a progress bar. If using parallel processing
        (n_jobs > 1), the verbose argument of :class:`joblib.Parallel`
        is used, otherwise uses tqdm.
    **kwargs
        Forwarded to :func:`EMD`. ``return_transport`` is ignored, since
        only distances are stored.

    Returns
    -------
    dm : :class:`numpy.ndarray`
        Returns a distance matrix shape ``(len(pdds), len(pdds_))``. The
        :math:`ij` th entry is the distance between ``pdds[i]`` and
        ``pdds_[j]`` given by Earth mover's distance.
    """

    # Only scalar distances fit in the matrix, so never ask for the plan.
    kwargs.pop('return_transport', None)
    n, m = len(pdds), len(pdds_)

    if n_jobs is not None and n_jobs not in (0, 1):
        # TODO: put results into preallocated empty array in place
        emd_func = partial(EMD, metric=metric, **kwargs)
        dm = Parallel(
            backend=backend, n_jobs=n_jobs, verbose=3 if verbose else 0
        )(
            delayed(emd_func)(pdd, pdd_) for pdd in pdds for pdd_ in pdds_
        )
        return np.array(dm).reshape((n, m))

    # k is only needed for the progress-bar label, so pdds[0] is only
    # touched when a label is shown and pdds is non-empty.
    desc = None
    if verbose and n:
        k = pdds[0].shape[-1] - 1
        desc = f'Comparing {n}x{m} PDDs (k={k})'

    dm = np.empty((n, m))
    # The context manager closes the bar even if EMD raises; with
    # disable=True the bar is inert.
    with tqdm.tqdm(
        desc=desc, total=n * m, disable=not verbose
    ) as progress_bar:
        for i, pdd in enumerate(pdds):
            for j, pdd_ in enumerate(pdds_):
                dm[i, j] = EMD(pdd, pdd_, metric=metric, **kwargs)
                progress_bar.update(1)

    return dm
514
515
516
def PDD_pdist(
        pdds: List[np.ndarray],
        metric: str = 'chebyshev',
        backend: str = 'multiprocessing',
        n_jobs: Optional[int] = None,
        verbose: bool = False,
        **kwargs
) -> np.ndarray:
    """Compare a set of PDDs pairwise, returning a condensed distance
    matrix. Supports parallelisation via joblib. If using
    parallelisation, make sure to include an if __name__ == '__main__'
    guard around this function.

    Parameters
    ----------
    pdds : List[:class:`numpy.ndarray`]
        A list of PDDs.
    metric : str or callable, default 'chebyshev'
        Usually PDD rows are compared with the Chebyshev/l-infinity
        distance. Accepts any metric accepted by
        :func:`scipy.spatial.distance.cdist`.
    backend : str, default 'multiprocessing'
        The parallelization backend implementation. For a list of
        supported backends, see the backend argument of
        :class:`joblib.Parallel`.
    n_jobs : int, default None
        Maximum number of concurrent jobs for parallel processing with
        ``joblib``. Set to -1 to use the maximum. ``None``, 0 and 1 run
        serially. Using parallel processing may be slower for small
        inputs.
    verbose : bool, default False
        Prints a progress bar. If using parallel processing
        (n_jobs > 1), the verbose argument of :class:`joblib.Parallel`
        is used, otherwise uses tqdm.
    **kwargs
        Forwarded to :func:`EMD`. ``return_transport`` is ignored, since
        only distances are stored.

    Returns
    -------
    cdm : :class:`numpy.ndarray`
        Returns a condensed distance matrix. Collapses a square distance
        matrix into a vector, just keeping the upper half. See the
        function :func:`squareform <scipy.spatial.distance.squareform>`
        from SciPy to convert to a symmetric square distance matrix.
    """

    # Only scalar distances fit in the matrix, so never ask for the plan.
    kwargs.pop('return_transport', None)
    m = len(pdds)

    # Same rule as PDD_cdist: anything other than None/0/1 is parallel,
    # so the documented n_jobs=-1 ("use the maximum") is honoured.
    if n_jobs is not None and n_jobs not in (0, 1):
        # TODO: put results into preallocated empty array in place
        emd_func = partial(EMD, metric=metric, **kwargs)
        cdm = Parallel(
            backend=backend, n_jobs=n_jobs, verbose=3 if verbose else 0
        )(
            delayed(emd_func)(pdds[i], pdds[j])
            for i, j in combinations(range(m), 2)
        )
        return np.array(cdm)

    # k is only needed for the progress-bar label, so pdds[0] is only
    # touched when a label is shown and pdds is non-empty.
    desc = None
    if verbose and m:
        k = pdds[0].shape[-1] - 1
        desc = f'Comparing {m} PDDs pairwise (k={k})'

    cdm_len = (m * (m - 1)) // 2
    cdm = np.empty(cdm_len, dtype=np.float64)
    # combinations() yields (i, j) with i < j in the order of a condensed
    # distance matrix. The context manager closes the bar even if EMD
    # raises; with disable=True the bar is inert.
    with tqdm.tqdm(
        desc=desc, total=cdm_len, disable=not verbose
    ) as progress_bar:
        for r, (i, j) in enumerate(combinations(range(m), 2)):
            cdm[r] = EMD(pdds[i], pdds[j], metric=metric, **kwargs)
            progress_bar.update(1)

    return cdm
588
589
590
def emd(
        pdd: np.ndarray, pdd_: np.ndarray, **kwargs
) -> Union[float, Tuple[float, np.ndarray]]:
    """Alias for :func:`EMD() <.compare.EMD>`.

    Both PDDs and every keyword argument are passed through unchanged,
    and whatever ``EMD`` returns is returned as-is.
    """
    result = EMD(pdd, pdd_, **kwargs)
    return result
595
596
597
def _unwrap_refcode_list(refcodes, **reader_kwargs):
    """Given string or list of strings, interpret as CSD refcodes and
    return a list of PeriodicSets.

    Parameters
    ----------
    refcodes : str or list of str
        A refcode or list of refcodes, passed on to ``CSDReader``.
    **reader_kwargs
        Keyword arguments for ``CSDReader``. A ``reader`` entry is
        dropped before the reader is constructed.

    Returns
    -------
    list
        Everything yielded by iterating the ``CSDReader``.

    Raises
    ------
    TypeError
        If ``refcodes`` is neither a string nor a list of strings.
    """

    # 'reader' is not forwarded to CSDReader.
    reader_kwargs.pop('reader', None)

    if isinstance(refcodes, list):
        for refcode in refcodes:
            if not isinstance(refcode, str):
                raise TypeError(
                    'amd.compare(csd_refcodes=True) expects a string or '
                    'list of strings, got a list containing '
                    f'{refcode.__class__.__name__}'
                )
    elif not isinstance(refcodes, str):
        raise TypeError(
            'amd.compare(csd_refcodes=True) expects a string or list of '
            f'strings, got {refcodes.__class__.__name__}'
        )

    return list(CSDReader(refcodes, **reader_kwargs))
615
616
617
def _unwrap_pset_list(psets, **reader_kwargs):
    """Given a valid input for amd.compare(), return a list of
    PeriodicSets. Accepts paths, PeriodicSets, tuples or lists
    thereof.

    Parameters
    ----------
    psets : path, PeriodicSet, tuple, or list of those
        The input to unwrap. Only a ``list`` is treated as a collection
        of inputs; a tuple is treated as a single item.
    **reader_kwargs
        Keyword arguments for ``CifReader``. A ``families`` entry is
        dropped before any reader is constructed.

    Returns
    -------
    list
        A flat list of the periodic sets obtained from every item.

    Raises
    ------
    ValueError
        If an item is not a PeriodicSet or tuple and cannot be
        interpreted as a path.
    """

    def _extract_periodicsets(item, **reader_kwargs):
        """Given a path, PeriodicSet or tuple, return a list of the
        PeriodicSet(s)."""

        if isinstance(item, PeriodicSet):
            return [item]
        # Builtin tuple rather than the deprecated typing.Tuple alias.
        if isinstance(item, tuple):
            # The first two entries are passed positionally to
            # PeriodicSet (presumably motif and cell; confirm against
            # the PeriodicSet constructor).
            return [PeriodicSet(item[0], item[1])]
        try:
            path = Path(item)
        except TypeError as err:
            raise ValueError(
                'amd.compare() expects a string, amd.PeriodicSet or tuple, '
                f'got {item.__class__.__name__}'
            ) from err
        return list(CifReader(path, **reader_kwargs))

    # 'families' is not forwarded to CifReader.
    reader_kwargs.pop('families', None)
    if isinstance(psets, list):
        return [s for i in psets
                for s in _extract_periodicsets(i, **reader_kwargs)]
    return _extract_periodicsets(psets, **reader_kwargs)
644