1
|
|
|
"""Functions for comparing AMDs and PDDs of crystals. |
2
|
|
|
""" |
3
|
|
|
|
4
|
|
|
import warnings |
5
|
|
|
from typing import List, Optional, Union |
6
|
|
|
from functools import partial |
7
|
|
|
from itertools import combinations |
8
|
|
|
import os |
9
|
|
|
|
10
|
|
|
import numpy as np |
11
|
|
|
import pandas as pd |
12
|
|
|
from scipy.spatial.distance import cdist, pdist, squareform |
13
|
|
|
from joblib import Parallel, delayed |
14
|
|
|
from progressbar import ProgressBar |
15
|
|
|
|
16
|
|
|
from .io import CifReader, CSDReader |
17
|
|
|
from .calculate import AMD, PDD |
18
|
|
|
from ._emd import network_simplex |
19
|
|
|
from .periodicset import PeriodicSet |
20
|
|
|
|
21
|
|
|
|
22
|
|
|
def compare(
        crystals,
        crystals_=None,
        by='AMD',
        k=100,
        **kwargs
) -> pd.DataFrame:
    r"""Given one or two sets of periodic set(s), refcode(s) or cif(s), compare
    them, returning a DataFrame of the distance matrix. Default is to compare by
    AMD with k=100. Accepts most keyword arguments accepted by the CifReader,
    CSDReader and compare functions, for a full list see the documentation
    Quick Start page. Note that using refcodes requires csd-python-api.

    Parameters
    ----------
    crystals : array or list of arrays
        One or a collection of paths, refcodes, file objects or
        :class:`.periodicset.PeriodicSet` s.
    crystals\_ : array or list of arrays, optional
        One or a collection of paths, refcodes, file objects or
        :class:`.periodicset.PeriodicSet` s. If omitted, ``crystals`` is
        compared pairwise with itself.
    by : str, default 'AMD'
        Invariant to compare by, either 'AMD' or 'PDD' (case-insensitive).
    k : int, default 100
        k value to use for the invariants (length of AMD, or number of
        columns in PDD).
    **kwargs
        Only the keys listed below are recognised; any other keyword is
        silently ignored.

        * Reader options (defaults): ``reader='ase'``, ``families=False``,
          ``remove_hydrogens=False``, ``disorder='skip'``,
          ``heaviest_component=False``, ``molecular_centres=False``,
          ``show_warnings=True``.
        * PDD calculation options (defaults): ``collapse=True``,
          ``collapse_tol=1e-4``, ``lexsort=False``.
        * Comparison options (defaults): ``metric='chebyshev'``,
          ``n_jobs=None``, ``verbose=0``, ``low_memory=False``.
          ``n_jobs`` and ``verbose`` are dropped when comparing by AMD;
          ``low_memory`` is dropped when comparing by PDD.

    Returns
    -------
    df : pandas.DataFrame
        DataFrame of the distance matrix for the given crystals compared by
        the chosen invariant. Rows are labelled with the names of the first
        set, columns with the names of the second set (or of the first set
        again when ``crystals_`` is None).

    Raises
    ------
    ValueError
        If by is not 'AMD' or 'PDD', or if either set given has no valid
        crystals to compare. Exceptions raised by the underlying readers are
        not caught here.

    Examples
    --------
    Compare everything in a .cif (default, AMD with k=100)::

        df = amd.compare('data.cif')

    Compare everything in one cif with all crystals in all cifs in a
    directory (PDD, k=50)::

        df = amd.compare('data.cif', 'dir/to/cifs', by='PDD', k=50)

    **Examples (csd-python-api only)**

    Compare two crystals by CSD refcode (PDD, k=50)::

        df = amd.compare('DEBXIT01', 'DEBXIT02', by='PDD', k=50)

    Compare everything in a refcode family (AMD, k=100)::

        df = amd.compare('DEBXIT', families=True)
    """

    by = by.upper()
    if by not in ('AMD', 'PDD'):
        raise ValueError(f"parameter 'by' in compare accepts 'AMD' or 'PDD', was passed {by}")

    reader_kwargs = {
        'reader': 'ase',
        'families': False,
        'remove_hydrogens': False,
        'disorder': 'skip',
        'heaviest_component': False,
        'molecular_centres': False,
        'show_warnings': True,
    }

    calc_kwargs = {
        'collapse': True,
        'collapse_tol': 1e-4,
        'lexsort': False,
    }

    compare_kwargs = {
        'metric': 'chebyshev',
        'n_jobs': None,
        'verbose': 0,
        'low_memory': False,
    }

    # Route each user-supplied keyword to whichever group of defaults knows it.
    for default_kwargs in (reader_kwargs, calc_kwargs, compare_kwargs):
        for key in kwargs.keys() & default_kwargs.keys():
            default_kwargs[key] = kwargs[key]

    crystals = _unwrap_periodicset_list(crystals, **reader_kwargs)
    if not crystals:
        raise ValueError('No valid crystals to compare in first set.')
    names = [s.name for s in crystals]

    if crystals_ is None:
        names_ = names
    else:
        crystals_ = _unwrap_periodicset_list(crystals_, **reader_kwargs)
        if not crystals_:
            raise ValueError('No valid crystals to compare in second set.')
        names_ = [s.name for s in crystals_]

    # Names were collected above, so the sets can now be swapped for
    # (molecular_centres, cell) pairs when that option is requested.
    if reader_kwargs['molecular_centres']:
        crystals = [(c.molecular_centres, c.cell) for c in crystals]
        if crystals_:
            crystals_ = [(c.molecular_centres, c.cell) for c in crystals_]

    if by == 'AMD':

        invs = [AMD(s, k) for s in crystals]
        # The AMD comparison functions take no parallel/verbosity options.
        compare_kwargs.pop('n_jobs', None)
        compare_kwargs.pop('verbose', None)

        if crystals_ is None:
            dm = squareform(AMD_pdist(invs, **compare_kwargs))
        else:
            invs_ = [AMD(s, k) for s in crystals_]
            dm = AMD_cdist(invs, invs_, **compare_kwargs)

    elif by == 'PDD':

        invs = [PDD(s, k, **calc_kwargs) for s in crystals]
        # The PDD comparison functions have no low_memory mode.
        compare_kwargs.pop('low_memory', None)

        if crystals_ is None:
            dm = squareform(PDD_pdist(invs, **compare_kwargs))
        else:
            # Both sets must be computed with the same PDD options, otherwise
            # user-supplied collapse/lexsort settings apply to one side only.
            invs_ = [PDD(s, k, **calc_kwargs) for s in crystals_]
            dm = PDD_cdist(invs, invs_, **compare_kwargs)

    return pd.DataFrame(dm, index=names, columns=names_)
151
|
|
|
|
152
|
|
|
|
153
|
|
|
def EMD(
        pdd: np.ndarray,
        pdd_: np.ndarray,
        metric: Optional[str] = 'chebyshev',
        return_transport: Optional[bool] = False,
        **kwargs):
    r"""Earth mover's distance (EMD) between two PDDs, also known as
    the Wasserstein metric.

    The pairwise distances between the rows of the two PDDs (first column
    excluded) are computed with :func:`scipy.spatial.distance.cdist`. These,
    together with the first column of each PDD, are handed to
    ``network_simplex``, which returns the distance and a transport plan.

    Parameters
    ----------
    pdd : numpy.ndarray
        PDD of a crystal.
    pdd\_ : numpy.ndarray
        PDD of a crystal.
    metric : str or callable, default 'chebyshev'
        EMD between PDDs requires defining a distance between PDD rows.
        By default, Chebyshev (L-infinity) distance is chosen as with AMDs.
        Accepts any metric accepted by :func:`scipy.spatial.distance.cdist`.
    return_transport: bool, default False
        Return a tuple ``(distance, transport_plan)`` with the optimal
        transport.
    **kwargs
        Extra keyword arguments forwarded to
        :func:`scipy.spatial.distance.cdist`.

    Returns
    -------
    emd : float
        Earth mover's distance between two PDDs. When ``return_transport``
        is true, a tuple ``(emd, transport_plan)`` is returned instead.

    Raises
    ------
    ValueError
        Thrown if ``pdd`` and ``pdd_`` do not have the same number of
        columns (``k`` value).
    """

    # Column 0 is kept out of the row-to-row distance computation.
    row_distances = cdist(pdd[:, 1:], pdd_[:, 1:], metric=metric, **kwargs)
    distance, plan = network_simplex(pdd[:, 0], pdd_[:, 0], row_distances)

    return (distance, plan) if return_transport else distance
194
|
|
|
|
195
|
|
|
|
196
|
|
|
def AMD_cdist(
        amds: Union[np.ndarray, List[np.ndarray]],
        amds_: Union[np.ndarray, List[np.ndarray]],
        metric: str = 'chebyshev',
        low_memory: bool = False,
        **kwargs
) -> np.ndarray:
    r"""Compare two sets of AMDs with each other, returning a distance matrix.
    This function is essentially identical to
    :func:`scipy.spatial.distance.cdist` with the default metric
    ``chebyshev``.

    Parameters
    ----------
    amds : array_like
        A list of AMDs. A single AMD (1-D input) is treated as a list of one.
    amds\_ : array_like
        A list of AMDs. A single AMD (1-D input) is treated as a list of one.
    metric : str or callable, default 'chebyshev'
        Usually AMDs are compared with the Chebyshev (L-infinity) distance.
        Can take any metric accepted by :func:`scipy.spatial.distance.cdist`.
    low_memory : bool, default False
        Use a slower but more memory efficient method for
        large collections of AMDs (Chebyshev metric only). If another metric
        is given, a ``UserWarning`` is issued and Chebyshev is used anyway;
        ``**kwargs`` are not used in this mode.
    **kwargs
        Extra keyword arguments forwarded to
        :func:`scipy.spatial.distance.cdist` when ``low_memory`` is false.

    Returns
    -------
    dm : numpy.ndarray
        A distance matrix shape ``(len(amds), len(amds_))``.
        ``dm[i, j]`` is the distance (given by ``metric``)
        between ``amds[i]`` and ``amds_[j]``.
    """

    amds = np.asarray(amds)
    amds_ = np.asarray(amds_)

    # Promote a lone AMD vector to a one-row matrix.
    if amds.ndim == 1:
        amds = np.array([amds])
    if amds_.ndim == 1:
        amds_ = np.array([amds_])

    if not low_memory:
        return cdist(amds, amds_, metric=metric, **kwargs)

    if metric != 'chebyshev':
        warnings.warn("Using only allowed metric 'chebyshev' for low_memory", UserWarning)

    # Fill one row at a time so only a (len(amds_), k) temporary exists at once.
    dm = np.empty((len(amds), len(amds_)))
    for row, amd_vec in enumerate(amds):
        dm[row] = np.abs(amds_ - amd_vec).max(axis=-1)
    return dm
246
|
|
|
|
247
|
|
|
|
248
|
|
|
def AMD_pdist(
        amds: Union[np.ndarray, List[np.ndarray]],
        metric: str = 'chebyshev',
        low_memory: bool = False,
        **kwargs
) -> np.ndarray:
    """Compare a set of AMDs pairwise, returning a condensed distance matrix.
    This function is essentially identical to
    :func:`scipy.spatial.distance.pdist` with the default metric
    ``chebyshev``.

    Parameters
    ----------
    amds : array_like
        An array/list of AMDs. A single AMD (1-D input) is treated as a
        collection of one.
    metric : str or callable, default 'chebyshev'
        Usually AMDs are compared with the Chebyshev (L-infinity) distance.
        Can take any metric accepted by :func:`scipy.spatial.distance.pdist`.
    low_memory : bool, default False
        Optionally use a slightly slower but more memory efficient method for
        large collections of AMDs (Chebyshev metric only). If another metric
        is given, a ``UserWarning`` is issued and Chebyshev is used anyway;
        ``**kwargs`` are not used in this mode.
    **kwargs
        Extra keyword arguments forwarded to
        :func:`scipy.spatial.distance.pdist` when ``low_memory`` is false.

    Returns
    -------
    numpy.ndarray
        Returns a condensed distance matrix. Collapses a square distance
        matrix into a vector, just keeping the upper half. See
        :func:`scipy.spatial.distance.squareform` to convert to a square
        distance matrix or for more on condensed distance matrices.
    """

    amds = np.asarray(amds)

    # Promote a lone AMD vector to a one-row matrix.
    if amds.ndim == 1:
        amds = np.array([amds])

    if not low_memory:
        return pdist(amds, metric=metric, **kwargs)

    m = len(amds)
    if metric != 'chebyshev':
        warnings.warn("Using only allowed metric 'chebyshev' for low_memory", UserWarning)

    cdm = np.empty((m * (m - 1)) // 2, dtype=np.double)
    start = 0
    for i in range(m):
        # Row i of the upper triangle holds the m - i - 1 distances to later AMDs.
        stop = start + m - i - 1
        cdm[start:stop] = np.abs(amds[i + 1:] - amds[i]).max(axis=-1)
        start = stop
    return cdm
297
|
|
|
|
298
|
|
|
|
299
|
|
|
def PDD_cdist(
        pdds: List[np.ndarray],
        pdds_: List[np.ndarray],
        metric: str = 'chebyshev',
        backend='multiprocessing',
        n_jobs=None,
        verbose=0,
        **kwargs
) -> np.ndarray:
    r"""Compare two sets of PDDs with each other, returning a distance matrix.
    Supports parallel processing via joblib. If using parallelisation, make
    sure to include a if __name__ == '__main__' guard around this function.

    Parameters
    ----------
    pdds : List[numpy.ndarray]
        A list of PDDs. A single 2-D array is treated as a list of one PDD.
    pdds\_ : List[numpy.ndarray]
        A list of PDDs. A single 2-D array is treated as a list of one PDD.
    metric : str or callable, default 'chebyshev'
        Usually PDD rows are compared with the Chebyshev/l-infinity distance.
        Can take any metric accepted by :func:`scipy.spatial.distance.cdist`.
    n_jobs : int, default None
        Maximum number of concurrent jobs for parallel processing with joblib.
        Values greater than 1, or negative values (joblib's convention, e.g.
        -1 to use the maximum possible), run in parallel; ``None``, 0 or 1
        run serially. Note that for small inputs (< 100), using parallel
        processing may be slower than the default n_jobs=None.
    verbose : int, default 0
        Controls verbosity. If using parallel processing, verbose is
        passed to :class:`joblib.Parallel`, where larger values = more
        verbosity. Otherwise, uses progressbar2 where the progressbar is
        either on or off.
    backend : str, default 'multiprocessing'
        Specifies the parallelization backend implementation. For a list of
        supported backends, see the backend argument of
        :class:`joblib.Parallel`.
    **kwargs
        Extra keyword arguments forwarded to ``EMD``. ``return_transport``
        is discarded, since only the distances fit in the returned matrix.

    Returns
    -------
    numpy.ndarray
        Returns a distance matrix shape ``(len(pdds), len(pdds_))``.
        The :math:`ij` th entry is the distance between ``pdds[i]``
        and ``pdds_[j]`` given by Earth mover's distance.
    """

    # A bare 2-D array is one PDD, not a collection of them.
    if isinstance(pdds, np.ndarray) and pdds.ndim == 2:
        pdds = [pdds]
    if isinstance(pdds_, np.ndarray) and pdds_.ndim == 2:
        pdds_ = [pdds_]

    kwargs.pop('return_transport', None)

    n, m = len(pdds), len(pdds_)

    # Negative n_jobs is joblib's way of asking for (almost) all CPUs, so it
    # must take the parallel path as well as values above 1.
    if n_jobs is not None and (n_jobs > 1 or n_jobs < 0):
        # TODO: put results into preallocated empty array in place
        emd_func = partial(EMD, metric=metric, **kwargs)
        flat = Parallel(backend=backend, n_jobs=n_jobs, verbose=verbose)(
            delayed(emd_func)(pdds[i], pdds_[j])
            for i in range(n) for j in range(m)
        )
        return np.array(flat).reshape((n, m))

    dm = np.empty((n, m))
    bar = ProgressBar(max_value=n * m) if verbose else None
    count = 0
    for i in range(n):
        for j in range(m):
            dm[i, j] = EMD(pdds[i], pdds_[j], metric=metric, **kwargs)
            if bar is not None:
                count += 1
                bar.update(count)
    return dm
372
|
|
|
|
373
|
|
|
|
374
|
|
|
def PDD_pdist(
        pdds: List[np.ndarray],
        metric: str = 'chebyshev',
        n_jobs=None,
        verbose=0,
        backend='multiprocessing',
        **kwargs
) -> np.ndarray:
    """Compare a set of PDDs pairwise, returning a condensed distance matrix.
    Supports parallelisation via joblib. If using parallelisation, make sure
    to include a if __name__ == '__main__' guard around this function.

    Parameters
    ----------
    pdds : List[numpy.ndarray]
        A list of PDDs.
    metric : str or callable, default 'chebyshev'
        Usually PDD rows are compared with the Chebyshev/l-infinity distance.
        Can take any metric accepted by :func:`scipy.spatial.distance.pdist`.
    n_jobs : int, default None
        Maximum number of concurrent jobs for parallel processing with joblib.
        Values greater than 1, or negative values (joblib's convention, e.g.
        -1 to use the maximum possible), run in parallel; ``None``, 0 or 1
        run serially. Note that for small inputs (< 100), using parallel
        processing may be slower than the default n_jobs=None.
    verbose : int, default 0
        Controls verbosity. If using parallel processing, verbose is
        passed to :class:`joblib.Parallel`, where larger values = more
        verbosity. Otherwise, uses progressbar2 where the progress bar is
        either on or off.
    backend : str, default 'multiprocessing'
        Specifies the parallelization backend implementation. For a list of
        supported backends, see the backend argument of
        :class:`joblib.Parallel`.
    **kwargs
        Extra keyword arguments forwarded to ``EMD``. ``return_transport``
        is discarded, since only the distances fit in the returned vector.

    Returns
    -------
    numpy.ndarray
        Returns a condensed distance matrix. Collapses a square
        distance matrix into a vector just keeping the upper half. See
        :func:`scipy.spatial.distance.squareform` to convert to a square
        distance matrix or for more on condensed distance matrices.
    """

    kwargs.pop('return_transport', None)

    m = len(pdds)

    # Negative n_jobs is joblib's way of asking for (almost) all CPUs, so it
    # must take the parallel path as well as values above 1.
    if n_jobs is not None and (n_jobs > 1 or n_jobs < 0):
        # TODO: put results into preallocated empty array in place
        emd_func = partial(EMD, metric=metric, **kwargs)
        flat = Parallel(backend=backend, n_jobs=n_jobs, verbose=verbose)(
            delayed(emd_func)(pdds[i], pdds[j])
            for i, j in combinations(range(m), 2)
        )
        return np.array(flat)

    cdm_len = (m * (m - 1)) // 2
    cdm = np.empty(cdm_len, dtype=np.double)
    bar = ProgressBar(max_value=cdm_len) if verbose else None
    # combinations() yields (i, j) with i < j in the row-major order that
    # a condensed distance matrix uses.
    for r, (i, j) in enumerate(combinations(range(m), 2)):
        cdm[r] = EMD(pdds[i], pdds[j], metric=metric, **kwargs)
        if bar is not None:
            # r is 0-based; report the number of completed comparisons.
            bar.update(r + 1)
    # Previously this path returned an unbound name instead of the result.
    return cdm
436
|
|
|
|
437
|
|
|
|
438
|
|
|
def emd(
        pdd: np.ndarray,
        pdd_: np.ndarray,
        metric: Optional[str] = 'chebyshev',
        return_transport: Optional[bool] = False,
        **kwargs):
    """Alias for amd.EMD().

    All arguments, including any extra keyword arguments, are forwarded
    unchanged to ``EMD`` and its result is returned as-is.
    """
    return EMD(
        pdd,
        pdd_,
        metric=metric,
        return_transport=return_transport,
        **kwargs
    )
446
|
|
|
|
447
|
|
|
|
448
|
|
|
def _unwrap_periodicset_list(psets_or_str, **reader_kwargs):
    """Normalise valid input for compare into a flat list of PeriodicSets.

    A single PeriodicSet is wrapped in a list. A ``list`` has each of its
    items expanded with ``_extract_periodicsets`` and the results are
    concatenated in order. Anything else is passed to
    ``_extract_periodicsets`` directly. ``reader_kwargs`` are forwarded to
    every ``_extract_periodicsets`` call.
    """

    if isinstance(psets_or_str, PeriodicSet):
        return [psets_or_str]

    if not isinstance(psets_or_str, list):
        return _extract_periodicsets(psets_or_str, **reader_kwargs)

    periodic_sets = []
    for entry in psets_or_str:
        for pset in _extract_periodicsets(entry, **reader_kwargs):
            periodic_sets.append(pset)
    return periodic_sets
459
|
|
|
|
460
|
|
|
|
461
|
|
|
def _extract_periodicsets(item, **reader_kwargs):
    """str (path/refcode), file or PeriodicSet --> list of PeriodicSets.

    A PeriodicSet is returned wrapped in a list. A string that names neither
    an existing file nor an existing directory is read with ``CSDReader``
    (presumably as a CSD refcode), with the ``reader`` option removed.
    Everything else is read with ``CifReader``, with the ``families`` and
    ``refcodes`` options removed.
    """

    if isinstance(item, PeriodicSet):
        return [item]

    not_a_path = isinstance(item, str) and not (
        os.path.isfile(item) or os.path.isdir(item)
    )

    if not_a_path:
        reader_kwargs.pop('reader', None)
        return list(CSDReader(item, **reader_kwargs))

    for unused_option in ('families', 'refcodes'):
        reader_kwargs.pop(unused_option, None)
    return list(CifReader(item, **reader_kwargs))
473
|
|
|
|