Passed
Push — master ( 0f0bce...406217 )
by Daniel
03:57
created

amd.compare.AMD_pdist()   A

Complexity

Conditions 5

Size

Total Lines 51
Code Lines 18

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 18
dl 0
loc 51
rs 9.0333
c 0
b 0
f 0
cc 5
nop 4

How to fix   Long Method   

Long Method

Small methods make your code easier to understand, in particular if combined with a good name. Besides, if your method is small, finding a good name is usually much easier.

For example, if you find yourself adding comments to a method's body, this is usually a good sign to extract the commented part to a new method, and use the comment as a starting point when coming up with a good name for this new method.

Commonly applied refactorings include:

1
"""Functions for comparing AMDs and PDDs of crystals.
2
"""
3
4
from typing import List, Optional, Union, Tuple
5
from functools import partial
6
from itertools import combinations
7
from pathlib import Path
8
9
import numpy as np
10
import pandas as pd
11
from scipy.spatial.distance import cdist, pdist, squareform
12
from joblib import Parallel, delayed
13
import tqdm
14
15
from .io import CifReader, CSDReader
16
from .calculate import AMD, PDD
17
from ._emd import network_simplex
18
from .periodicset import PeriodicSet
19
from .utils import neighbours_from_distance_matrix
20
21
# Names exported by ``from <this module> import *``.
__all__ = [
    'compare',
    'EMD',
    'AMD_cdist',
    'AMD_pdist',
    'PDD_cdist',
    'PDD_pdist',
    'emd',
]
25
26
27
def compare(
        crystals,
        crystals_=None,
        by: str = 'AMD',
        k: int = 100,
        nearest: Optional[int] = None,
        reader: str = 'gemmi',
        remove_hydrogens: bool = False,
        disorder: str = 'skip',
        heaviest_component: bool = False,
        molecular_centres: bool = False,
        csd_refcodes: bool = False,
        refcode_families: bool = False,
        show_warnings: bool = True,
        collapse_tol: float = 1e-4,
        metric: str = 'chebyshev',
        n_jobs: Optional[int] = None,
        backend: str = 'multiprocessing',
        verbose: bool = False,
        low_memory: bool = False,
        **kwargs
) -> pd.DataFrame:
    r"""Given one or two sets of crystals, compare by AMD or PDD and
    return a pandas DataFrame of the distance matrix.

    Given one or two paths to CIFs, periodic sets, CSD refcodes or lists
    thereof, compare by AMD or PDD and return a pandas DataFrame of the
    distance matrix. Default is to compare by AMD with k = 100. Accepts
    most keyword arguments accepted by
    :class:`CifReader <.io.CifReader>`,
    :class:`CSDReader <.io.CSDReader>` and functions from
    :mod:`.compare`.

    Parameters
    ----------
    crystals : list of str or :class:`PeriodicSet <.periodicset.PeriodicSet>`
        A path, :class:`PeriodicSet <.periodicset.PeriodicSet>`, tuple
        or a list of those.
    crystals\_ : list of str or :class:`PeriodicSet <.periodicset.PeriodicSet>`, optional
        A path, :class:`PeriodicSet <.periodicset.PeriodicSet>`, tuple
        or a list of those. If omitted, ``crystals`` is compared
        pairwise with itself.
    by : str, default 'AMD'
        Use AMD or PDD to compare crystals (case-insensitive).
    k : int, default 100
        Parameter for AMD/PDD, the number of neighbour atoms to consider
        for each atom in a unit cell.
    nearest : int, default None
        Find a number of nearest neighbours instead of a full distance
        matrix between crystals.
    reader : str, optional
        The backend package used to parse the CIF. The default is
        :code:`gemmi`, :code:`pymatgen` and :code:`ase` are also
        accepted, as well as :code:`ccdc` if csd-python-api is
        installed. The ccdc reader should be able to read any format
        accepted by :class:`ccdc.io.EntryReader`, though only CIFs have
        been tested. Forced to :code:`ccdc` when ``heaviest_component``
        or ``molecular_centres`` is True.
    remove_hydrogens : bool, optional
        Remove hydrogens from the crystals.
    disorder : str, optional
        Controls how disordered structures are handled. Default is
        ``skip`` which skips any crystal with disorder, since disorder
        conflicts with the periodic set model. To read disordered
        structures anyway, choose either :code:`ordered_sites` to remove
        atoms with disorder or :code:`all_sites` include all atoms
        regardless of disorder.
    heaviest_component : bool, optional, csd-python-api only
        Removes all but the heaviest molecule in
        the asymmetric unit, intended for removing solvents.
    molecular_centres : bool, default False, csd-python-api only
        Use the centres of molecules for comparison
        instead of centres of atoms.
    csd_refcodes : bool, optional, csd-python-api only
        Interpret ``crystals`` and ``crystals_`` as CSD refcodes or
        lists thereof, rather than paths. Implied by
        ``refcode_families``.
    refcode_families : bool, optional, csd-python-api only
        Read all entries whose refcode starts with
        the given strings, or 'families' (e.g. giving 'DEBXIT' reads all
        entries with refcodes starting with DEBXIT).
    show_warnings : bool, optional
        Controls whether warnings that arise during reading are printed.
    collapse_tol: float, default 1e-4, ``by='PDD'`` only
        If two PDD rows have all elements closer
        than ``collapse_tol``, they are merged and weights are given to
        rows in proportion to the number of times they appeared.
    metric : str or callable, default 'chebyshev'
        The metric to compare AMDs/PDDs with. AMDs are compared directly
        with this metric. EMD is the metric used between PDDs, which
        requires giving a metric to use between PDD rows. Chebyshev
        (L-infinity) distance is the default. Accepts any metric
        accepted by :func:`scipy.spatial.distance.cdist`.
    n_jobs : int, default None, ``by='PDD'`` only
        Maximum number of concurrent jobs for
        parallel processing with :code:`joblib`. Set to -1 to use the
        maximum. Using parallel processing may be slower for small
        inputs.
    backend : str, default 'multiprocessing', ``by='PDD'`` only
        The parallelization backend implementation for PDD comparisons.
        For a list of supported backends, see the backend argument of
        :class:`joblib.Parallel`.
    verbose : bool, default False
        Prints a progress bar when reading crystals, calculating
        AMDs/PDDs and comparing PDDs. If using parallel processing
        (n_jobs > 1), the verbose argument of :class:`joblib.Parallel`
        is used, otherwise uses ``tqdm``.
    low_memory : bool, default False, ``by='AMD'`` only
        Use a slower but more memory efficient
        method for large collections of AMDs (metric 'chebyshev' only).
    **kwargs
        Any remaining keyword arguments are forwarded to the comparison
        function (``AMD_pdist``/``AMD_cdist``/``PDD_pdist``/
        ``PDD_cdist``).

    Returns
    -------
    df : :class:`pandas.DataFrame`
        DataFrame of the distance matrix for the given crystals compared
        by the chosen invariant. If ``nearest`` is given, the DataFrame
        instead has columns ``'ID 1', 'DIST 1', ..., 'ID n', 'DIST n'``
        holding the names of and distances to each crystal's nearest
        neighbours.

    Raises
    ------
    ValueError
        If by is not 'AMD' or 'PDD', if either set given have no valid
        crystals to compare, or if crystals or crystals\_ are an invalid
        type.
    TypeError
        If refcodes are requested and crystals or crystals\_ is not a
        string or list of strings.

    Examples
    --------
    Compare everything in a .cif (default, AMD with k=100)::

        df = amd.compare('data.cif')

    Compare everything in one cif with all crystals in all cifs in a
    directory (PDD, k=50)::

        df = amd.compare('data.cif', 'dir/to/cifs', by='PDD', k=50)

    **Examples (csd-python-api only)**

    Compare two crystals by CSD refcode (PDD, k=50)::

        df = amd.compare('DEBXIT01', 'DEBXIT02', csd_refcodes=True, by='PDD', k=50)

    Compare everything in a refcode family (AMD, k=100)::

        df = amd.compare('DEBXIT', refcode_families=True)
    """

    by = by.upper()
    if by not in ('AMD', 'PDD'):
        raise ValueError(
            "'by' parameter of amd.compare() must be 'AMD' or 'PDD' (passed "
            f"'{by}')"
        )

    # These two options are csd-python-api only, so they force the ccdc
    # reader regardless of what was passed.
    if heaviest_component or molecular_centres:
        reader = 'ccdc'

    # Asking for refcode families implies the inputs are refcodes.
    if refcode_families:
        csd_refcodes = True

    reader_kwargs = {
        'reader': reader,
        'families': refcode_families,
        'remove_hydrogens': remove_hydrogens,
        'disorder': disorder,
        'heaviest_component': heaviest_component,
        'molecular_centres': molecular_centres,
        'show_warnings': show_warnings,
        'verbose': verbose,
    }

    compare_kwargs = {
        'metric': metric,
        'n_jobs': n_jobs,
        'backend': backend,
        'verbose': verbose,
        'low_memory': low_memory,
        **kwargs
    }

    unwrap = _unwrap_refcode_list if csd_refcodes else _unwrap_pset_list

    def _read(source, position):
        """Read one input into a list of periodic sets and return
        ``(names, iterable)``, where the iterable is wrapped in a tqdm
        progress bar if ``verbose``."""
        psets = unwrap(source, **reader_kwargs)
        if not psets:
            if position == 'First':
                raise ValueError(
                    'First argument passed to amd.compare() contains no valid '
                    'crystals/periodic sets'
                )
            raise ValueError(
                'Second argument passed to amd.compare() contains no '
                'valid crystals/periodic sets'
            )
        pset_names = [s.name for s in psets]
        if verbose:
            return pset_names, tqdm.tqdm(psets, desc='Calculating', delay=1)
        return pset_names, psets

    # Get list(s) of periodic sets from first input
    names, container = _read(crystals, 'First')

    # Get list(s) of periodic sets from second input if given
    if crystals_ is None:
        names_, container_ = names, None
    else:
        names_, container_ = _read(crystals_, 'Second')

    if by == 'AMD':
        invs = [AMD(s, k) for s in container]
        if isinstance(container, tqdm.tqdm):
            container.close()
        # The AMD comparison functions do not take the parallelisation
        # or verbosity options; drop them before forwarding.
        for key in ('n_jobs', 'backend', 'verbose'):
            compare_kwargs.pop(key, None)

        if container_ is None:
            dm = AMD_pdist(invs, **compare_kwargs)
        else:
            invs_ = [AMD(s, k) for s in container_]
            dm = AMD_cdist(invs, invs_, **compare_kwargs)

    else:
        # by == 'PDD' (validated above), so dm is bound on every path.
        invs = [PDD(s, k, collapse_tol=collapse_tol) for s in container]
        # low_memory only applies to AMD comparisons.
        compare_kwargs.pop('low_memory', None)

        if container_ is None:
            dm = PDD_pdist(invs, **compare_kwargs)
        else:
            invs_ = [PDD(s, k, collapse_tol=collapse_tol) for s in container_]
            dm = PDD_cdist(invs, invs_, **compare_kwargs)

    if nearest:
        nn_dm, inds = neighbours_from_distance_matrix(nearest, dm)
        data = {}
        for i in range(nearest):
            data['ID ' + str(i+1)] = [names_[j] for j in inds[:, i]]
            data['DIST ' + str(i+1)] = nn_dm[:, i]
        df = pd.DataFrame(data, index=names)
    else:
        # Pairwise comparisons return a condensed (1-D) matrix; expand it
        # to a square matrix before labelling.
        if dm.ndim == 1:
            dm = squareform(dm)
        df = pd.DataFrame(dm, index=names, columns=names_)

    return df
277
278
279
def EMD(
        pdd: np.ndarray,
        pdd_: np.ndarray,
        metric: Optional[str] = 'chebyshev',
        return_transport: Optional[bool] = False,
        **kwargs
) -> Union[float, Tuple[float, np.ndarray]]:
    r"""Calculate the Earth mover's distance (EMD), aka the Wasserstein
    metric, between two PDDs.

    The first column of each PDD is handed to the network simplex
    solver as the two distributions to transport between; the remaining
    columns are compared row-against-row with ``metric`` to build the
    cost matrix.

    Parameters
    ----------
    pdd : :class:`numpy.ndarray`
        PDD of a crystal.
    pdd\_ : :class:`numpy.ndarray`
        PDD of a crystal.
    metric : str or callable, default 'chebyshev'
        Distance used between PDD rows. Chebyshev (L-infinity) distance
        is the default, as with AMDs. Accepts any metric accepted by
        :func:`scipy.spatial.distance.cdist`.
    return_transport: bool, default False
        If True, return a tuple ``(emd, transport_plan)`` instead of
        just the distance, where transport_plan describes the optimal
        flow.
    **kwargs
        Extra keyword arguments passed to
        :func:`scipy.spatial.distance.cdist`.

    Returns
    -------
    emd : float
        Earth mover's distance between the two PDDs, or the tuple
        ``(emd, transport_plan)`` if ``return_transport`` is True.

    Raises
    ------
    ValueError
        Thrown if ``pdd`` and ``pdd_`` do not have the same number of
        columns.
    """

    # Cost of moving mass between every pair of rows (first column
    # excluded from the comparison).
    row_costs = cdist(pdd[:, 1:], pdd_[:, 1:], metric=metric, **kwargs)
    distance, flow = network_simplex(pdd[:, 0], pdd_[:, 0], row_costs)

    return (distance, flow) if return_transport else distance
323
324
325
def AMD_cdist(
        amds,
        amds_,
        metric: str = 'chebyshev',
        low_memory: bool = False,
        **kwargs
) -> np.ndarray:
    r"""Compare every AMD in one collection with every AMD in another,
    returning a distance matrix. Essentially
    :func:`scipy.spatial.distance.cdist` with ``chebyshev`` as the
    default metric, plus a low memory option.

    Parameters
    ----------
    amds : ArrayLike
        A list/array of AMDs. A single (1-D) AMD is treated as a
        collection of one.
    amds\_ : ArrayLike
        A list/array of AMDs. A single (1-D) AMD is treated as a
        collection of one.
    metric : str or callable, default 'chebyshev'
        Usually AMDs are compared with the Chebyshev (L-infinity)
        distance. Accepts any metric accepted by
        :func:`scipy.spatial.distance.cdist`.
    low_memory : bool, default False
        Use a slower but more memory efficient method for large
        collections of AMDs (metric 'chebyshev' only). Extra keyword
        arguments are not used in this mode.
    **kwargs
        Extra keyword arguments passed to
        :func:`scipy.spatial.distance.cdist`.

    Returns
    -------
    dm : :class:`numpy.ndarray`
        A distance matrix shape ``(len(amds), len(amds_))``. ``dm[i, j]``
        is the distance (given by ``metric``) between ``amds[i]`` and
        ``amds_[j]``.

    Raises
    ------
    ValueError
        If ``low_memory`` is True and ``metric`` is not 'chebyshev'.
    """

    left = np.asarray(amds)
    right = np.asarray(amds_)

    # Promote a lone AMD vector to a one-row matrix.
    if left.ndim == 1:
        left = np.array([left])
    if right.ndim == 1:
        right = np.array([right])

    if not low_memory:
        return cdist(left, right, metric=metric, **kwargs)

    if metric != 'chebyshev':
        raise ValueError(
            "'low_memory' parameter of amd.AMD_cdist() only implemented "
            "with metric='chebyshev'."
        )

    # Fill one row at a time so only a (len(amds_), k) temporary exists.
    dm = np.empty((len(left), len(right)))
    for row, vector in enumerate(left):
        dm[row] = np.amax(np.abs(right - vector), axis=-1)
    return dm
377
378
379
def AMD_pdist(
        amds, metric: str = 'chebyshev', low_memory: bool = False, **kwargs
) -> np.ndarray:
    """Compare all AMDs in a collection with each other, returning a
    condensed distance matrix. Essentially
    :func:`scipy.spatial.distance.pdist` with ``chebyshev`` as the
    default metric, plus a low memory parameter.

    Parameters
    ----------
    amds : ArrayLike
        A list/array of AMDs. A single (1-D) AMD is treated as a
        collection of one.
    metric : str or callable, default 'chebyshev'
        Usually AMDs are compared with the Chebyshev (L-infinity)
        distance. Accepts any metric accepted by
        :func:`scipy.spatial.distance.pdist`.
    low_memory : bool, default False
        Use a slower but more memory efficient method for large
        collections of AMDs (metric 'chebyshev' only). Extra keyword
        arguments are not used in this mode.
    **kwargs
        Extra keyword arguments passed to
        :func:`scipy.spatial.distance.pdist`.

    Returns
    -------
    cdm : :class:`numpy.ndarray`
        Returns a condensed distance matrix. Collapses a square distance
        matrix into a vector, just keeping the upper half. See the
        function :func:`squareform <scipy.spatial.distance.squareform>`
        from SciPy to convert to a symmetric square distance matrix.

    Raises
    ------
    ValueError
        If ``low_memory`` is True and ``metric`` is not 'chebyshev'.
    """

    vectors = np.asarray(amds)

    # Promote a lone AMD vector to a one-row matrix.
    if vectors.ndim == 1:
        vectors = np.array([vectors])

    if not low_memory:
        return pdist(vectors, metric=metric, **kwargs)

    count = len(vectors)
    if metric != 'chebyshev':
        raise ValueError(
            "'low_memory' parameter of amd.AMD_pdist() only implemented "
            "with metric='chebyshev'."
        )

    cdm = np.empty((count * (count - 1)) // 2, dtype=np.float64)
    # Row i of the upper triangle holds the count - i - 1 distances from
    # vector i to every later vector; write them as one contiguous slice.
    start = 0
    for i, vector in enumerate(vectors):
        stop = start + (count - i - 1)
        cdm[start:stop] = np.amax(np.abs(vectors[i+1:] - vector), axis=-1)
        start = stop
    return cdm
430
431
432
def PDD_cdist(
        pdds: List[np.ndarray],
        pdds_: List[np.ndarray],
        metric: str = 'chebyshev',
        backend: str = 'multiprocessing',
        n_jobs: Optional[int] = None,
        verbose: bool = False,
        **kwargs
) -> np.ndarray:
    r"""Compare every PDD in one list with every PDD in another by Earth
    mover's distance, returning a distance matrix. Supports parallel
    processing via joblib. If using parallelisation, make sure to
    include an if __name__ == '__main__' guard around this function.

    Parameters
    ----------
    pdds : List[:class:`numpy.ndarray`]
        A list of PDDs.
    pdds\_ : List[:class:`numpy.ndarray`]
        A list of PDDs.
    metric : str or callable, default 'chebyshev'
        Usually PDD rows are compared with the Chebyshev/l-infinity
        distance. Accepts any metric accepted by
        :func:`scipy.spatial.distance.cdist`.
    backend : str, default 'multiprocessing'
        The parallelization backend implementation. For a list of
        supported backends, see the backend argument of
        :class:`joblib.Parallel`.
    n_jobs : int, default None
        Maximum number of concurrent jobs for parallel processing with
        ``joblib``. Set to -1 to use the maximum. Using parallel
        processing may be slower for small inputs. ``None``, 0 and 1
        all run serially without joblib.
    verbose : bool, default False
        Prints a progress bar. If using parallel processing, the verbose
        argument of :class:`joblib.Parallel` is used, otherwise uses
        tqdm.
    **kwargs
        Extra keyword arguments forwarded to ``EMD`` (any
        ``return_transport`` entry is discarded).

    Returns
    -------
    dm : :class:`numpy.ndarray`
        Returns a distance matrix shape ``(len(pdds), len(pdds_))``. The
        :math:`ij` th entry is the distance between ``pdds[i]`` and
        ``pdds_[j]`` given by Earth mover's distance.
    """

    # With return_transport set, EMD returns a tuple rather than a
    # float, which cannot be stored in the distance matrix.
    kwargs.pop('return_transport', None)
    # k is only used to label the progress bar.
    k = pdds[0].shape[-1] - 1
    joblib_verbosity = 3 if verbose else 0
    n_rows, n_cols = len(pdds), len(pdds_)

    use_joblib = n_jobs is not None and n_jobs not in (0, 1)
    if use_joblib:
        # TODO: put results into preallocated empty array in place
        emd_func = partial(EMD, metric=metric, **kwargs)
        flat = Parallel(
            backend=backend, n_jobs=n_jobs, verbose=joblib_verbosity
        )(
            delayed(emd_func)(pdds[i], pdds_[j])
            for i in range(n_rows) for j in range(n_cols)
        )
        return np.array(flat).reshape((n_rows, n_cols))

    dm = np.empty((n_rows, n_cols))
    progress_bar = None
    if verbose:
        desc = f'Comparing {n_rows}x{n_cols} PDDs (k={k})'
        progress_bar = tqdm.tqdm(desc=desc, total=n_rows * n_cols)

    for i in range(n_rows):
        for j in range(n_cols):
            dm[i, j] = EMD(pdds[i], pdds_[j], metric=metric, **kwargs)
            if progress_bar is not None:
                progress_bar.update(1)

    if progress_bar is not None:
        progress_bar.close()

    return dm
506
507
508
def PDD_pdist(
        pdds: List[np.ndarray],
        metric: str = 'chebyshev',
        backend: str = 'multiprocessing',
        n_jobs: Optional[int] = None,
        verbose: bool = False,
        **kwargs
) -> np.ndarray:
    """Compare a set of PDDs pairwise by Earth mover's distance,
    returning a condensed distance matrix. Supports parallelisation via
    joblib. If using parallelisation, make sure to include an
    if __name__ == '__main__' guard around this function.

    Parameters
    ----------
    pdds : List[:class:`numpy.ndarray`]
        A list of PDDs.
    metric : str or callable, default 'chebyshev'
        Usually PDD rows are compared with the Chebyshev/l-infinity
        distance. Accepts any metric accepted by
        :func:`scipy.spatial.distance.cdist`.
    backend : str, default 'multiprocessing'
        The parallelization backend implementation. For a list of
        supported backends, see the backend argument of
        :class:`joblib.Parallel`.
    n_jobs : int, default None
        Maximum number of concurrent jobs for parallel processing with
        ``joblib``. Set to -1 to use the maximum. Using parallel
        processing may be slower for small inputs. ``None``, 0 and 1
        all run serially without joblib.
    verbose : bool, default False
        Prints a progress bar. If using parallel processing, the verbose
        argument of :class:`joblib.Parallel` is used, otherwise uses
        tqdm.
    **kwargs
        Extra keyword arguments forwarded to ``EMD`` (any
        ``return_transport`` entry is discarded).

    Returns
    -------
    cdm : :class:`numpy.ndarray`
        Returns a condensed distance matrix. Collapses a square distance
        matrix into a vector, just keeping the upper half. See the
        function :func:`squareform <scipy.spatial.distance.squareform>`
        from SciPy to convert to a symmetric square distance matrix.
        Empty if fewer than two PDDs are given.
    """

    # With return_transport set, EMD returns a tuple rather than a
    # float, which cannot be stored in the condensed matrix.
    kwargs.pop('return_transport', None)

    m = len(pdds)
    cdm_len = (m * (m - 1)) // 2

    # Nothing to compare: avoid indexing pdds[0] or starting workers.
    if cdm_len == 0:
        return np.empty(0, dtype=np.float64)

    # Same gate as PDD_cdist, so negative values such as n_jobs=-1
    # ("use the maximum") run in parallel rather than silently serially.
    if n_jobs is not None and n_jobs not in (0, 1):
        # TODO: put results into preallocated empty array in place
        emd_func = partial(EMD, metric=metric, **kwargs)
        cdm = Parallel(
            backend=backend, n_jobs=n_jobs, verbose=3 if verbose else 0
        )(
            delayed(emd_func)(pdds[i], pdds[j])
            for i, j in combinations(range(m), 2)
        )
        return np.array(cdm)

    cdm = np.empty(cdm_len, dtype=np.float64)

    progress_bar = None
    if verbose:
        # k is only needed to label the progress bar.
        k = pdds[0].shape[-1] - 1
        desc = f'Comparing {m} PDDs pairwise (k={k})'
        progress_bar = tqdm.tqdm(desc=desc, total=cdm_len)

    try:
        # combinations() yields (i, j) with i < j in the row-major order
        # of the upper triangle, matching the condensed layout.
        for r, (i, j) in enumerate(combinations(range(m), 2)):
            cdm[r] = EMD(pdds[i], pdds[j], metric=metric, **kwargs)
            if progress_bar is not None:
                progress_bar.update(1)
    finally:
        if progress_bar is not None:
            progress_bar.close()

    return cdm
580
581
582
def emd(
        pdd: np.ndarray, pdd_: np.ndarray, **kwargs
) -> Union[float, Tuple[float, np.ndarray]]:
    """Alias for :func:`EMD() <.compare.EMD>`.

    Both PDDs and every keyword argument are forwarded unchanged, and
    whatever ``EMD`` returns is returned as-is.
    """
    result = EMD(pdd, pdd_, **kwargs)
    return result
587
588
589
def _unwrap_refcode_list(refcodes, **reader_kwargs):
    """Given string or list of strings, interpret as CSD refcodes and
    return a list of PeriodicSets.

    ``refcodes`` is validated first, then passed with the remaining
    ``reader_kwargs`` to ``CSDReader``; everything the reader yields is
    collected into a list.

    Raises
    ------
    TypeError
        If ``refcodes`` is neither a string nor a list containing only
        strings.
    """

    # The 'reader' entry is dropped before calling CSDReader
    # (presumably it only applies to CifReader - confirm against .io).
    reader_kwargs.pop('reader', None)

    if isinstance(refcodes, list):
        for refcode in refcodes:
            if not isinstance(refcode, str):
                raise TypeError(
                    'amd.compare(csd_refcodes=True) expects a string or list '
                    'of strings, got a list containing '
                    f'{refcode.__class__.__name__}'
                )
    elif not isinstance(refcodes, str):
        raise TypeError(
            'amd.compare(csd_refcodes=True) expects a string or list of '
            f'strings, got {refcodes.__class__.__name__}'
        )

    return list(CSDReader(refcodes, **reader_kwargs))
607
608
609
def _unwrap_pset_list(psets, **reader_kwargs):
    """Given a valid input for amd.compare(), return a list of
    PeriodicSets. Accepts paths, PeriodicSets, tuples or lists
    thereof.

    Each item is handled by type: a PeriodicSet is used as-is, a tuple
    is turned into ``PeriodicSet(item[0], item[1])``, and anything else
    is treated as a path and read with ``CifReader``.

    Raises
    ------
    ValueError
        If an item is not a PeriodicSet or tuple and cannot be
        converted to a path.
    """

    def _extract_periodicsets(item, **reader_kwargs):
        """Given a path, PeriodicSet or tuple, return a list of the
        PeriodicSet(s)."""

        if isinstance(item, PeriodicSet):
            return [item]
        # Builtin tuple rather than the deprecated typing.Tuple alias.
        if isinstance(item, tuple):
            return [PeriodicSet(item[0], item[1])]
        try:
            path = Path(item)
        except TypeError as err:
            # Keep the original TypeError as the explicit cause.
            raise ValueError(
                'amd.compare() expects a string, amd.PeriodicSet or tuple, '
                f'got {item.__class__.__name__}'
            ) from err
        return list(CifReader(path, **reader_kwargs))

    # The 'families' entry is dropped before calling CifReader
    # (presumably it only applies to CSDReader - confirm against .io).
    reader_kwargs.pop('families', None)
    if isinstance(psets, list):
        return [s for i in psets
                for s in _extract_periodicsets(i, **reader_kwargs)]
    return _extract_periodicsets(psets, **reader_kwargs)
636