Passed
Push — master ( c9cb02...43c976 )
by Daniel
01:44
created

amd.calculate.PDD_to_AMD()   A

Complexity

Conditions 1

Size

Total Lines 15
Code Lines 2

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 2
dl 0
loc 15
rs 10
c 0
b 0
f 0
cc 1
nop 1
1
"""Functions for calculating AMDs and PDDs (and SDDs) of periodic and finite sets.
2
"""
3
4
from typing import Union, Tuple
5
import collections
6
7
import numpy as np
8
import scipy.spatial
9
import scipy.special
10
11
from ._nearest_neighbours import nearest_neighbours, nearest_neighbours_minval, generate_concentric_cloud
0 ignored issues
show
Coding Style introduced by
This line is too long as per the coding-style (105/100).

This check looks for lines that are too long. You can specify the maximum line length.

Loading history...
Unused Code introduced by
Unused generate_concentric_cloud imported from _nearest_neighbours
Loading history...
12
from .periodicset import PeriodicSet
13
from .utils import diameter
14
15
# Type accepted wherever a periodic set is expected in this module: either a
# PeriodicSet object or a (motif, cell) pair of ndarrays.
PeriodicSet_or_Tuple = Union[PeriodicSet, Tuple[np.ndarray, np.ndarray]]
16
17
18
def AMD(periodic_set: PeriodicSet_or_Tuple, k: int) -> np.ndarray:
    """The AMD of a periodic set (crystal) up to `k`.

    Parameters
    ----------
    periodic_set : :class:`.periodicset.PeriodicSet` or tuple of ndarrays
        A periodic set represented by a :class:`.periodicset.PeriodicSet` or
        by a tuple (motif, cell) with coordinates in Cartesian form.
    k : int
        Length of AMD returned.

    Returns
    -------
    ndarray
        An ndarray of shape (k,), the AMD of ``periodic_set`` up to `k`.

    Examples
    --------
    Make list of AMDs with ``k=100`` for crystals in mycif.cif::

        amds = []
        for periodic_set in amd.CifReader('mycif.cif'):
            amds.append(amd.AMD(periodic_set, 100))

    Make list of AMDs with ``k=10`` for crystals in these CSD refcode families::

        amds = []
        for periodic_set in amd.CSDReader(['HXACAN', 'ACSALA'], families=True):
            amds.append(amd.AMD(periodic_set, 10))

    Manually pass a periodic set as a tuple (motif, cell)::

        # simple cubic lattice
        motif = np.array([[0,0,0]])
        cell = np.array([[1,0,0], [0,1,0], [0,0,1]])
        cubic_amd = amd.AMD((motif, cell), 100)
    """

    motif, cell, asymmetric_unit, multiplicities = _extract_motif_and_cell(periodic_set)

    # Only the first item of the triple returned by nearest_neighbours is used;
    # its rows are averaged column-wise below.
    dists, _, _ = nearest_neighbours(motif, cell, k, asymmetric_unit=asymmetric_unit)

    # multiplicities is None when no asymmetric unit was extracted, in which
    # case np.average reduces to a plain unweighted mean over the rows.
    return np.average(dists, axis=0, weights=multiplicities)
59
60
61
def PDD(
        periodic_set: PeriodicSet_or_Tuple,
        k: int,
        lexsort: bool = True,
        collapse: bool = True,
        collapse_tol: float = 1e-4
) -> np.ndarray:
    """The PDD of a periodic set (crystal) up to `k`.

    Parameters
    ----------
    periodic_set : :class:`.periodicset.PeriodicSet` or tuple of ndarrays
        A periodic set represented by a :class:`.periodicset.PeriodicSet` or
        by a tuple (motif, cell) with coordinates in Cartesian form.
    k : int
        Number of columns in the PDD (the returned matrix has an additional first
        column containing weights).
    lexsort : bool, optional
        Whether or not to lexicographically order the rows. Default True.
    collapse: bool, optional
        Whether or not to collapse identical rows (within a tolerance). Default True.
    collapse_tol: float
        If two rows have all entries closer than collapse_tol, they get collapsed.
        Default is 1e-4.

    Returns
    -------
    ndarray
        An ndarray with k+1 columns, the PDD of ``periodic_set`` up to `k`.

    Examples
    --------
    Make list of PDDs with ``k=100`` for crystals in mycif.cif::

        pdds = []
        for periodic_set in amd.CifReader('mycif.cif'):
            # do not lexicographically order rows
            pdds.append(amd.PDD(periodic_set, 100, lexsort=False))

    Make list of PDDs with ``k=10`` for crystals in these CSD refcode families::

        pdds = []
        for periodic_set in amd.CSDReader(['HXACAN', 'ACSALA'], families=True):
            # do not collapse rows
            pdds.append(amd.PDD(periodic_set, 10, collapse=False))

    Manually pass a periodic set as a tuple (motif, cell)::

        # simple cubic lattice
        motif = np.array([[0,0,0]])
        cell = np.array([[1,0,0], [0,1,0], [0,0,1]])
        cubic_amd = amd.PDD((motif, cell), 100)
    """

    motif, cell, asymmetric_unit, multiplicities = _extract_motif_and_cell(periodic_set)
    dists, _, _ = nearest_neighbours(motif, cell, k, asymmetric_unit=asymmetric_unit)

    # Without multiplicities every motif point gets the same weight; otherwise
    # the multiplicities are normalised so that the weights sum to 1.
    if multiplicities is None:
        n_points = motif.shape[0]
        weights = np.full((n_points, ), 1 / n_points)
    else:
        weights = multiplicities / np.sum(multiplicities)

    if collapse:
        weights, dists = _collapse_rows(weights, dists, collapse_tol)

    # First column holds the weights, the remaining columns the distances.
    pdd = np.hstack((weights[:, None], dists))

    if lexsort:
        # np.lexsort treats its LAST key as the primary one; rot90 places the
        # first distance column last, so rows are ordered by the distance
        # columns from left to right (the weight column is not a sort key).
        pdd = pdd[np.lexsort(np.rot90(dists))]

    return pdd
132
133
134
def PDD_to_AMD(pdd: np.ndarray) -> np.ndarray:
    """Calculates AMD from a PDD. Faster than computing both from scratch.

    Parameters
    ----------
    pdd : np.ndarray
        The PDD of a periodic set: a 2D array (or array-like) whose first
        column holds the row weights and whose remaining columns hold distances.

    Returns
    -------
    ndarray
        The AMD of the periodic set, i.e. the weighted average of the distance
        columns of ``pdd`` using the first column as weights.
    """

    # Accept nested lists/tuples as well as ndarrays.
    pdd = np.asarray(pdd)
    row_weights = pdd[:, 0]
    distances = pdd[:, 1:]
    return np.average(distances, weights=row_weights, axis=0)
149
150
151
def AMD_finite(motif: np.ndarray) -> np.ndarray:
    """The AMD of a finite point set (up to k = `len(motif) - 1`).

    Parameters
    ----------
    motif : ndarray
        Cartesian coordinates of points in a set. Shape (n_points, dimensions)

    Returns
    -------
    ndarray
        A vector of length len(motif) - 1, the AMD of ``motif``.

    Examples
    --------
    Find AMD distance between finite trapezium and kite point sets::

        trapezium = np.array([[0,0],[1,1],[3,1],[4,0]])
        kite      = np.array([[0,0],[1,1],[1,-1],[4,0]])

        trap_amd = amd.AMD_finite(trapezium)
        kite_amd = amd.AMD_finite(kite)

        dist = amd.AMD_pdist(trap_amd, kite_amd)
    """

    # Full symmetric matrix of pairwise distances between the points.
    dm = scipy.spatial.distance.squareform(scipy.spatial.distance.pdist(motif))

    # Sort each row ascending and drop the first column, which is each
    # point's zero distance to itself.
    dists = np.sort(dm, axis=-1)[:, 1:]

    return np.average(dists, axis=0)
179
180
181
def PDD_finite(
        motif: np.ndarray,
        lexsort: bool = True,
        collapse: bool = True,
        collapse_tol: float = 1e-4
) -> np.ndarray:
    """The PDD of a finite point set (up to k = `len(motif) - 1`).

    Parameters
    ----------
    motif : ndarray
        Cartesian coordinates of points in a set. Shape (n_points, dimensions)
    lexsort : bool, optional
        Whether or not to lexicographically order the rows. Default True.
    collapse: bool, optional
        Whether or not to collapse identical rows (within a tolerance). Default True.
    collapse_tol: float
        If two rows have all entries closer than collapse_tol, they get collapsed.
        Default is 1e-4.

    Returns
    -------
    ndarray
        An ndarray with len(motif) columns, the PDD of ``motif``. The first
        column holds the weights, the remaining columns the distances.

    Examples
    --------
    Find PDD distance between finite trapezium and kite point sets::

        trapezium = np.array([[0,0],[1,1],[3,1],[4,0]])
        kite      = np.array([[0,0],[1,1],[1,-1],[4,0]])

        trap_pdd = amd.PDD_finite(trapezium)
        kite_pdd = amd.PDD_finite(kite)

        dist = amd.emd(trap_pdd, kite_pdd)
    """

    # Normalise to an ndarray so array-likes (e.g. nested lists) work here
    # just as they do in AMD_finite.
    motif = np.asarray(motif)
    m = motif.shape[0]

    dm = scipy.spatial.distance.squareform(scipy.spatial.distance.pdist(motif))

    # Sort each row ascending and drop the first column, which is each
    # point's zero distance to itself.
    dists = np.sort(dm, axis=-1)[:, 1:]

    # Every point of a finite set carries the same weight.
    weights = np.full((m, ), 1 / m)

    if collapse:
        weights, dists = _collapse_rows(weights, dists, collapse_tol)

    pdd = np.hstack((weights[:, None], dists))

    if lexsort:
        # np.lexsort treats its LAST key as the primary one; rot90 places the
        # first distance column last, so rows are ordered by the distance
        # columns from left to right (the weight column is not a sort key).
        pdd = pdd[np.lexsort(np.rot90(dists))]

    return pdd
233
234
235
def PDD_reconstructable(
        periodic_set: PeriodicSet_or_Tuple,
        lexsort: bool = True
) -> np.ndarray:
    """The PDD of a periodic set with `k` (no of columns) large enough such that
    the periodic set can be reconstructed from the PDD.

    The number of columns is not chosen by the caller: neighbours are
    collected up to a minimum value of twice the diameter of the unit cell.

    Parameters
    ----------
    periodic_set : :class:`.periodicset.PeriodicSet` or tuple of ndarrays
        A periodic set represented by a :class:`.periodicset.PeriodicSet` or
        by a tuple (motif, cell) with coordinates in Cartesian form.
    lexsort : bool, optional
        Whether or not to lexicographically order the rows. Default True.

    Returns
    -------
    ndarray
        The array returned by ``nearest_neighbours_minval`` for the motif and
        cell, with its rows lexicographically ordered if ``lexsort`` is True.

    Raises
    ------
    ValueError
        If the cell is not 2 or 3 dimensional.

    Examples
    --------
    Manually pass a periodic set as a tuple (motif, cell)::

        # simple cubic lattice
        motif = np.array([[0,0,0]])
        cell = np.array([[1,0,0], [0,1,0], [0,0,1]])
        cubic_pdd = amd.PDD_reconstructable((motif, cell))
    """

    motif, cell, _, _ = _extract_motif_and_cell(periodic_set)
    dims = cell.shape[0]

    if dims not in (2, 3):
        raise ValueError('Reconstructing from PDD only implemented for 2 and 3 dimensions')

    min_val = diameter(cell) * 2
    pdd = nearest_neighbours_minval(motif, cell, min_val)

    if lexsort:
        # NOTE(review): unlike PDD(), every column of the array is used as a
        # sort key here -- confirm against what nearest_neighbours_minval
        # returns (e.g. whether a weights column is present).
        pdd = pdd[np.lexsort(np.rot90(pdd))]

    return pdd
301
302
303
def PPC(periodic_set: PeriodicSet_or_Tuple) -> float:
    r"""The point packing coefficient (PPC) of ``periodic_set``.

    The PPC is a constant of any periodic set determining the
    asymptotic behaviour of its AMD or PDD as :math:`k \rightarrow \infty`.

    As :math:`k \rightarrow \infty`, the ratio :math:`\text{AMD}_k / \sqrt[n]{k}`
    approaches the PPC (as does any row of its PDD).

    For a unit cell :math:`U` and :math:`m` motif points in :math:`n` dimensions,

    .. math::

        \text{PPC} = \sqrt[n]{\frac{\text{Vol}[U]}{m V_n}}

    where :math:`V_n` is the volume of a unit sphere in :math:`n` dimensions,
    :math:`V_n = \pi^{n/2} / \Gamma(n/2 + 1)`.

    Parameters
    ----------
    periodic_set : :class:`.periodicset.PeriodicSet` or tuple of
        ndarrays (motif, cell) representing the periodic set in Cartesian form.

    Returns
    -------
    float
        The PPC of ``periodic_set``.
    """

    motif, cell, _, _ = _extract_motif_and_cell(periodic_set)
    m, n = motif.shape

    # The determinant is signed (negative for a left-handed basis); the cell
    # volume is its absolute value. Without this the n-th root below is NaN.
    cell_volume = np.abs(np.linalg.det(cell))

    # Closed form for the unit n-ball volume. Equivalent to the separate
    # even/odd factorial formulas, but avoids math.factorial on a float
    # (a TypeError on Python >= 3.10) and np.math (removed in NumPy 2.0).
    unit_ball_volume = np.pi ** (n / 2) / scipy.special.gamma(n / 2 + 1)

    return (cell_volume / (m * unit_ball_volume)) ** (1. / n)
341
342
343
def AMD_estimate(periodic_set: PeriodicSet_or_Tuple, k: int) -> np.ndarray:
    r"""Calculates an estimate of AMD based on the PPC, using the fact that

    .. math::

        \lim_{k\rightarrow\infty}\frac{\text{AMD}_k}{\sqrt[n]{k}}
        = \sqrt[n]{\frac{\text{Vol}[U]}{m V_n}}

    where :math:`U` is the unit cell, :math:`m` is the number of motif points and
    :math:`V_n` is the volume of a unit sphere in :math:`n`-dimensional space.

    Parameters
    ----------
    periodic_set : :class:`.periodicset.PeriodicSet` or tuple of ndarrays
        A periodic set represented by a :class:`.periodicset.PeriodicSet` or
        by a tuple (motif, cell) with coordinates in Cartesian form.
    k : int
        Length of the estimate returned.

    Returns
    -------
    ndarray
        An ndarray of shape (k,) whose i-th entry (i = 1, ..., k) is
        :math:`\text{PPC} \cdot \sqrt[n]{i}`.
    """

    motif, cell, _, _ = _extract_motif_and_cell(periodic_set)
    n = motif.shape[1]
    ppc = PPC((motif, cell))

    # Vectorised form of PPC * i ** (1 / n) for i = 1, ..., k.
    return np.arange(1, k + 1) ** (1. / n) * ppc
358
359
360
def _extract_motif_and_cell(periodic_set: PeriodicSet_or_Tuple):
    """Unpack ``periodic_set`` into ``(motif, cell, asymmetric_unit, multiplicities)``.

    `periodic_set` is either a :class:`.periodicset.PeriodicSet`, a tuple of
    ndarrays (motif, cell), or a single ndarray, which is taken to be the
    motif (the returned cell is then None).

    The asymmetric unit and wyckoff multiplicities are only extracted from a
    :class:`.periodicset.PeriodicSet` whose ``tags`` contain both
    ``'asymmetric_unit'`` and ``'wyckoff_multiplicities'``; otherwise both
    are returned as None.
    """

    asymmetric_unit, multiplicities = None, None

    if isinstance(periodic_set, PeriodicSet):
        motif, cell = periodic_set.motif, periodic_set.cell

        tags = periodic_set.tags
        # Both pieces of symmetry data are needed together, or neither is used.
        if 'asymmetric_unit' in tags and 'wyckoff_multiplicities' in tags:
            asymmetric_unit = periodic_set.asymmetric_unit
            multiplicities = periodic_set.wyckoff_multiplicities

    elif isinstance(periodic_set, np.ndarray):
        # A bare array is a motif with no unit cell.
        motif, cell = periodic_set, None
    else:
        motif, cell = periodic_set[0], periodic_set[1]

    return motif, cell, asymmetric_unit, multiplicities
381
382
383
def _collapse_rows(weights, dists, collapse_tol):
    """Given a vector `weights`, matrix `dists` and tolerance `collapse_tol`, collapse
    the identical rows of `dists` (two rows are identical if all of their entries
    are within `collapse_tol` of each other) and collapse the same entries of
    `weights` (adding the entries that merge).

    Returns the pair (weights, dists); both are returned unchanged if no two
    distinct rows are within tolerance.
    """

    # Broadcasting compares every row with every other row: diffs[i, j] holds
    # the entry-wise absolute differences between rows i and j.
    diffs = np.abs(dists[:, None] - dists)
    overlapping = np.all(diffs <= collapse_tol, axis=-1)

    res = _group_overlapping_and_sum_weights(weights, overlapping)
    if res is None:
        # Nothing to merge.
        return weights, dists

    merged_weights, keep_inds = res
    return merged_weights, dists[keep_inds]
398
399
400
def _group_overlapping_and_sum_weights(weights, overlapping):
    """Group indices marked as overlapping and sum the weights of each group.

    `overlapping` is a square boolean matrix where entry (i, j) is True if
    rows i and j should be merged. `weights` is a vector with one entry per
    row.

    Returns None if no entry above the diagonal of `overlapping` is True
    (nothing to merge). Otherwise returns a tuple ``(weights, keep_inds)``:
    an ndarray with the summed weight of each group, and a list with the
    smallest index in each group, in matching order.
    """

    # Only pairs of distinct rows matter; the diagonal (a row compared with
    # itself) is excluded by taking the strict upper triangle.
    if not np.triu(overlapping, 1).any():
        return None

    # Map each row index to a group label. A row not yet labelled starts a
    # new group; every index it overlaps with then takes that row's label.
    groups = {}
    group = 0
    for i, row in enumerate(overlapping):
        if i not in groups:
            groups[i] = group
            group += 1

        for j in np.argwhere(row).T[0]:
            groups[j] = groups[i]

    # Invert the mapping: label -> indices, with indices in ascending order.
    groupings = collections.defaultdict(list)
    for key, val in sorted(groups.items()):
        groupings[val].append(key)

    members = list(groupings.values())
    keep_inds = [inds[0] for inds in members]
    summed_weights = np.array([np.sum(weights[inds]) for inds in members])

    return summed_weights, keep_inds
424