Passed
Push — master ( 7f7a74...2c655b )
by Daniel
05:50
created

amd.calculate.PPC()   A

Complexity

Conditions 3

Size

Total Lines 43
Code Lines 10

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 10
dl 0
loc 43
rs 9.9
c 0
b 0
f 0
cc 3
nop 1
1
"""Functions for calculating the average minimum distance (AMD) and
point-wise distance distribution (PDD) isometric invariants of
periodic crystals and finite sets, together with the point packing
coefficient (PPC) and a PPC-based estimate of the AMD.
"""
5
6
import collections
7
from typing import Tuple, Union
8
9
import numpy as np
10
import numpy.typing as npt
11
import numba
12
from scipy.spatial.distance import pdist, squareform
13
from scipy.special import factorial
14
15
from .periodicset import PeriodicSet
16
from ._nearest_neighbors import nearest_neighbors, nearest_neighbors_minval
17
from .utils import diameter
18
19
# Annotation alias: a NumPy array with any floating-point dtype.
FloatArray = npt.NDArray[np.floating]

# Names exported by ``from ... import *``; everything else is private.
__all__ = [
    'AMD', 'PDD', 'PDD_to_AMD',
    'AMD_finite', 'PDD_finite', 'PDD_reconstructable',
    'PPC', 'AMD_estimate',
]
31
32
33
def AMD(pset: PeriodicSet, k: int) -> FloatArray:
    """Return the average minimum distance (AMD) of a periodic set
    (crystal).

    The AMD of a periodic set is a geometry based descriptor independent
    of choice of motif and unit cell. It is a vector, the (weighted)
    average of the :func:`PDD <.calculate.PDD>` matrix, which has one
    row for each (unique) point in the unit cell containing distances to
    k nearest neighbors.

    Parameters
    ----------
    pset : :class:`amd.PeriodicSet <.periodicset.PeriodicSet>`
        A periodic set (crystal) consisting of a unit cell and motif of
        points.
    k : int
        The number of neighboring points (atoms) considered for each
        point in the unit cell.

    Returns
    -------
    :class:`numpy.ndarray`
        A :class:`numpy.ndarray` shape ``(k, )``, the AMD of
        ``pset`` up to k.

    Raises
    ------
    ValueError
        If ``pset`` is not a :class:`PeriodicSet`.

    Examples
    --------
    Make list of AMDs with k = 100 for crystals in data.cif::

        amds = []
        for periodic_set in amd.CifReader('data.cif'):
            amds.append(amd.AMD(periodic_set, 100))

    Make list of AMDs with k = 10 for crystals in these CSD refcode families::

        amds = []
        for periodic_set in amd.CSDReader(['HXACAN', 'ACSALA'], families=True):
            amds.append(amd.AMD(periodic_set, 10))

    Manually create a periodic set from a motif and a unit cell::

        # simple cubic lattice
        motif = np.array([[0,0,0]])
        cell = np.array([[1,0,0], [0,1,0], [0,0,1]])
        periodic_set = amd.PeriodicSet(motif, cell)
        cubic_amd = amd.AMD(periodic_set, 100)
    """

    if not isinstance(pset, PeriodicSet):
        raise ValueError(
            f'Expected {PeriodicSet.__name__}, got {pset.__class__.__name__}'
        )

    # Without symmetry information every motif point is its own row with
    # weight 1; otherwise only asymmetric unit points are used, weighted
    # by their multiplicities.
    if pset.asym_unit is None or pset.multiplicities is None:
        asym_unit = pset.motif
        weights = np.ones((asym_unit.shape[0], ), dtype=np.float64)
    else:
        asym_unit = pset.motif[pset.asym_unit]
        weights = pset.multiplicities

    # k + 1 neighbours are requested because the first column is dropped
    # by _average_columns_except_first (presumably each point's zero
    # distance to itself -- TODO confirm against nearest_neighbors).
    dists = nearest_neighbors(pset.motif, pset.cell, asym_unit, k + 1)
    return _average_columns_except_first(dists, weights)
95
96
97
def PDD(
        pset: PeriodicSet,
        k: int,
        lexsort: bool = True,
        collapse: bool = True,
        collapse_tol: float = 1e-4,
        return_row_groups: bool = False
) -> Union[FloatArray, Tuple[FloatArray, list]]:
    """Return the point-wise distance distribution (PDD) of a periodic
    set (crystal).

    The PDD of a periodic set is a geometry based descriptor independent
    of choice of motif and unit cell. It is a matrix where each row
    corresponds to a point in the motif, containing a weight followed by
    distances to the k nearest neighbors of the point.

    Parameters
    ----------
    pset : :class:`amd.PeriodicSet <.periodicset.PeriodicSet>`
        A periodic set (crystal) consisting of a unit cell and motif of
        points.
    k : int
        The number of neighbors considered for each atom (point) in the
        unit cell. The returned matrix has k + 1 columns, the first
        column for weights of rows.
    lexsort : bool, default True
        Lexicographically order the rows.
    collapse: bool, default True
        Collapse repeated rows (within tolerance ``collapse_tol``).
    collapse_tol: float, default 1e-4
        If two rows have all elements closer than ``collapse_tol``, they
        are merged and weights are given to rows in proportion to the
        number of times they appeared.
    return_row_groups: bool, default False
        If True, return a tuple ``(pdd, groups)`` where ``groups``
        contains information about which rows in ``pdd`` correspond to
        which points. If ``pset.asym_unit`` or ``pset.multiplicities``
        is None, then ``groups[i]`` contains indices of points in
        ``pset.motif`` corresponding to ``pdd[i]``. Otherwise, PDD rows
        correspond to points in the asymmetric unit, and ``groups[i]``
        contains indices of points in ``pset.motif[pset.asym_unit]``.

    Returns
    -------
    pdd : :class:`numpy.ndarray`
        A :class:`numpy.ndarray` with k+1 columns, the PDD of
        ``pset`` up to k. The first column contains the weights
        of rows. If ``return_row_groups`` is True, returns a tuple with
        types (:class:`numpy.ndarray`, list).

    Raises
    ------
    ValueError
        If ``pset`` is not a :class:`PeriodicSet`.

    Examples
    --------
    Make list of PDDs with ``k=100`` for crystals in data.cif::

        pdds = []
        for periodic_set in amd.CifReader('data.cif'):
            # do not lexicographically order rows
            pdds.append(amd.PDD(periodic_set, 100, lexsort=False))

    Make list of PDDs with ``k=10`` for crystals in these CSD refcode
    families::

        pdds = []
        for periodic_set in amd.CSDReader(['HXACAN', 'ACSALA'], families=True):
            # do not collapse rows
            pdds.append(amd.PDD(periodic_set, 10, collapse=False))

    Manually create a periodic set from a motif and a unit cell::

        # simple cubic lattice
        motif = np.array([[0,0,0]])
        cell = np.array([[1,0,0], [0,1,0], [0,0,1]])
        periodic_set = amd.PeriodicSet(motif, cell)
        cubic_pdd = amd.PDD(periodic_set, 100)
    """

    if not isinstance(pset, PeriodicSet):
        raise ValueError(
            f'Expected {PeriodicSet.__name__}, got {pset.__class__.__name__}'
        )

    # Row weights are normalised by the motif size m: uniform 1/m without
    # symmetry information, otherwise multiplicity / m per asymmetric
    # unit point.
    m = pset.motif.shape[0]
    if pset.asym_unit is None or pset.multiplicities is None:
        asym_unit = pset.motif
        weights = np.full((m, ), 1 / m, dtype=np.float64)
    else:
        asym_unit = pset.motif[pset.asym_unit]
        weights = pset.multiplicities / m

    # Request k + 1 neighbours and drop the first column (presumably each
    # point's zero distance to itself -- TODO confirm in nearest_neighbors).
    dists = nearest_neighbors(pset.motif, pset.cell, asym_unit, k + 1)
    dists = dists[:, 1:]
    groups = [[i] for i in range(len(dists))]

    if collapse:
        # Rows whose entries all agree within collapse_tol are merged:
        # weights are summed and distances averaged within each group.
        overlapping = pdist(dists, metric='chebyshev') <= collapse_tol
        if overlapping.any():
            groups = _collapse_into_groups(overlapping)
            weights = np.array([np.sum(weights[group]) for group in groups])
            dists = np.array(
                [np.average(dists[group], axis=0) for group in groups],
                dtype=np.float64
            )

    pdd = np.empty(shape=(len(dists), k + 1), dtype=np.float64)

    if lexsort:
        # np.lexsort treats its last key as primary; rot90 puts the first
        # distance column last, giving a row-wise lexicographic order.
        lex_ordering = np.lexsort(np.rot90(dists))
        pdd[:, 0] = weights[lex_ordering]
        pdd[:, 1:] = dists[lex_ordering]
        if return_row_groups:
            groups = [groups[i] for i in lex_ordering]
    else:
        pdd[:, 0] = weights
        pdd[:, 1:] = dists

    if return_row_groups:
        return pdd, groups
    return pdd
216
217
218
def PDD_to_AMD(pdd: FloatArray) -> FloatArray:
    """Calculate an AMD from a PDD. Faster than computing both from
    scratch.

    The AMD is the average of the PDD's distance columns, weighted by the
    PDD's first (weight) column. For a periodic set ``pset``,
    ``amd.PDD_to_AMD(amd.PDD(pset, k))`` matches ``amd.AMD(pset, k)``
    up to floating point error (and only approximately if the PDD had
    rows collapsed, since collapsed rows are averaged).

    Parameters
    ----------
    pdd : :class:`numpy.ndarray`
        A PDD matrix: the first column holds row weights, the remaining
        columns hold distances. Array-likes (e.g. nested lists) are
        converted with :func:`numpy.asarray`.

    Returns
    -------
    :class:`numpy.ndarray`
        The AMD, a vector with ``pdd.shape[1] - 1`` entries.
    """
    pdd = np.asarray(pdd)
    return np.average(pdd[:, 1:], weights=pdd[:, 0], axis=0)
234
235
236
def AMD_finite(motif: FloatArray) -> FloatArray:
    """Compute the AMD of a finite set of m points, up to k = m - 1.

    Entry ``i`` of the result is the mean, over all points, of the
    distance from a point to its (i + 1)-th nearest other point.

    Parameters
    ----------
    motif : :class:`numpy.ndarray`
        Collection of points, one per row.

    Returns
    -------
    :class:`numpy.ndarray`
        A vector shape (motif.shape[0] - 1, ), the AMD of ``motif``.

    Examples
    --------
    The (L-infinity) AMD distance between finite trapezium and kite
    point sets::

        trapezium = np.array([[0,0],[1,1],[3,1],[4,0]])
        kite      = np.array([[0,0],[1,1],[1,-1],[4,0]])

        trap_amd = amd.AMD_finite(trapezium)
        kite_amd = amd.AMD_finite(kite)

        l_inf_dist = np.amax(np.abs(trap_amd - kite_amd))
    """

    # Full pairwise distance matrix, each row sorted ascending; the first
    # entry of every row (a zero) corresponds to the point itself.
    neighbour_dists = squareform(pdist(motif))
    neighbour_dists.sort(axis=-1)
    return neighbour_dists[:, 1:].mean(axis=0)
265
266
267
def PDD_finite(
        motif: FloatArray,
        lexsort: bool = True,
        collapse: bool = True,
        collapse_tol: float = 1e-4,
        return_row_groups: bool = False
) -> Union[FloatArray, Tuple[FloatArray, list]]:
    """Return the PDD of a finite m-point set up to k = m - 1.

    Parameters
    ----------
    motif : :class:`numpy.ndarray`
        Coordinates of a set of points.
    lexsort : bool, default True
        Whether or not to lexicographically order the rows.
    collapse: bool, default True
        Whether or not to collapse repeated rows (within tolerance
        ``collapse_tol``).
    collapse_tol: float, default 1e-4
        If two rows have all elements closer than ``collapse_tol``, they
        are merged and weights are given to rows in proportion to the
        number of times they appeared.
    return_row_groups: bool, default False
        If True, return a tuple ``(pdd, groups)`` where ``groups[i]``
        contains indices of points in ``motif`` corresponding to
        ``pdd[i]``.

    Returns
    -------
    pdd : :class:`numpy.ndarray`
        A :class:`numpy.ndarray` with m columns (where m is the number
        of points), the PDD of ``motif``. The first column contains the
        weights of rows. A single point gives the 1x1 array ``[[1.]]``.
        If ``return_row_groups`` is True, returns a tuple with types
        (:class:`numpy.ndarray`, list).

    Examples
    --------
    Find PDD distance between finite trapezium and kite point sets::

        trapezium = np.array([[0,0],[1,1],[3,1],[4,0]])
        kite      = np.array([[0,0],[1,1],[1,-1],[4,0]])

        trap_pdd = amd.PDD_finite(trapezium)
        kite_pdd = amd.PDD_finite(kite)

        dist = amd.EMD(trap_pdd, kite_pdd)
    """

    m = motif.shape[0]
    # Sorted rows of the pairwise distance matrix, with the leading zero
    # (each point's distance to itself) dropped.
    dists = np.sort(squareform(pdist(motif)), axis=-1)[:, 1:]
    weights = np.full((m, ), 1 / m)
    groups = [[i] for i in range(len(dists))]

    if collapse:
        # Rows whose entries all agree within collapse_tol are merged:
        # weights are summed and distances averaged within each group.
        overlapping = pdist(dists, metric='chebyshev') <= collapse_tol
        if overlapping.any():
            groups = _collapse_into_groups(overlapping)
            weights = np.array([np.sum(weights[group]) for group in groups])
            dists = np.array([
                np.average(dists[group], axis=0) for group in groups
            ], dtype=np.float64)

    pdd = np.empty(shape=(len(weights), m), dtype=np.float64)

    # np.lexsort needs at least one key. A single point has no distance
    # columns (and only one row), so there is nothing to order; sorting
    # it would raise TypeError.
    if lexsort and dists.shape[1] > 0:
        # np.lexsort treats its last key as primary; rot90 puts the first
        # distance column last, giving a row-wise lexicographic order.
        lex_ordering = np.lexsort(np.rot90(dists))
        pdd[:, 0] = weights[lex_ordering]
        pdd[:, 1:] = dists[lex_ordering]
        if return_row_groups:
            groups = [groups[i] for i in lex_ordering]
    else:
        pdd[:, 0] = weights
        pdd[:, 1:] = dists

    if return_row_groups:
        return pdd, groups
    return pdd
343
344
345
def PDD_reconstructable(pset: PeriodicSet, lexsort: bool = True) -> FloatArray:
    """Compute a PDD of ``pset`` with enough columns (large enough `k`)
    that the periodic set can be reconstructed from it.

    Parameters
    ----------
    pset : :class:`amd.PeriodicSet <.periodicset.PeriodicSet>`
        A periodic set (crystal) consisting of a unit cell and motif of
        points. Only 2 and 3 dimensional sets are supported.
    lexsort : bool, default True
        Whether or not to lexicographically order the rows.

    Returns
    -------
    pdd : :class:`numpy.ndarray`
        The PDD of ``pset`` with enough columns to reconstruct ``pset``
        using :func:`amd.reconstruct.reconstruct`.

    Raises
    ------
    ValueError
        If ``pset`` is not a :class:`PeriodicSet`, or its cell is not 2
        or 3 dimensional.
    """

    if not isinstance(pset, PeriodicSet):
        raise ValueError(
            f'Expected {PeriodicSet.__name__}, got {pset.__class__.__name__}'
        )

    if pset.cell.shape[0] not in (2, 3):
        raise ValueError(
            'Reconstructing from PDD is only possible for 2 and 3 dimensions.'
        )

    # Twice the cell diameter is passed as ``min_val`` to the neighbour
    # search (presumably the distance every row must reach -- TODO confirm
    # against nearest_neighbors_minval).
    search_bound = diameter(pset.cell) * 2
    dist_rows, _, _ = nearest_neighbors_minval(
        pset.motif, pset.cell, search_bound
    )

    if not lexsort:
        return dist_rows
    # rot90 makes the first column the last (primary) np.lexsort key.
    return dist_rows[np.lexsort(np.rot90(dist_rows))]
381
382
383
def PPC(pset: PeriodicSet) -> float:
    r"""Compute the point packing coefficient (PPC) of ``pset``.

    The PPC is a constant of any periodic set that governs the
    asymptotic behaviour of its AMD and PDD: as
    :math:`k \rightarrow \infty`, the ratio
    :math:`\text{AMD}_k / \sqrt[n]{k}` converges to the PPC, and so does
    the corresponding ratio for any row of its PDD.

    With unit cell :math:`U` and :math:`m` motif points in :math:`n`
    dimensions,

    .. math::

        \text{PPC} = \sqrt[n]{\frac{\text{Vol}[U]}{m V_n}}

    where :math:`V_n` is the volume of the unit ball in :math:`n`
    dimensions.

    Parameters
    ----------
    pset : :class:`amd.PeriodicSet <.periodicset.PeriodicSet>`
        A periodic set (crystal) consisting of a unit cell and motif of
        points.

    Returns
    -------
    ppc : float
        The PPC of ``pset``.

    Raises
    ------
    ValueError
        If ``pset`` is not a :class:`PeriodicSet`.
    """

    if not isinstance(pset, PeriodicSet):
        raise ValueError(
            f'Expected {PeriodicSet.__name__}, got {pset.__class__.__name__}'
        )

    num_points, dims = pset.motif.shape
    half, is_odd = divmod(dims, 2)

    # Closed forms of the unit ball volume:
    #   n = 2t     ->  pi^t / t!
    #   n = 2t + 1 ->  2 * t! * (4 pi)^t / n!
    if is_odd:
        unit_ball_volume = (2 * factorial(half) * (4 * np.pi) ** half) / factorial(dims)
    else:
        unit_ball_volume = (np.pi ** half) / factorial(half)

    cell_volume = np.abs(np.linalg.det(pset.cell))
    return (cell_volume / (num_points * unit_ball_volume)) ** (1.0 / dims)
426
427
428
def AMD_estimate(pset: PeriodicSet, k: int) -> FloatArray:
    r"""Calculate an estimate of AMD based on the PPC.

    Uses the asymptotic behaviour of the AMD: its k-th entry is
    approximated by :math:`\text{PPC} \sqrt[n]{k}`.

    Parameters
    ----------
    pset : :class:`amd.PeriodicSet <.periodicset.PeriodicSet>`
        A periodic set (crystal) consisting of a unit cell and motif of
        points.
    k : int
        The number of entries in the estimate, i.e. the number of
        neighbours considered (as for ``amd.AMD(pset, k)``).

    Returns
    -------
    amd_est : :class:`numpy.ndarray`
        An array shape (k, ), where ``amd_est[i]``
        :math:`= \text{PPC} \sqrt[n]{i + 1}`, with n the dimension of
        ``pset.cell``.

    Raises
    ------
    ValueError
        If ``pset`` is not a :class:`PeriodicSet`.
    """

    if not isinstance(pset, PeriodicSet):
        raise ValueError(
            f'Expected {PeriodicSet.__name__}, got {pset.__class__.__name__}'
        )
    n = pset.cell.shape[0]
    # Neighbour indices 1..k (not 0..k-1): entry i estimates AMD_{i+1}.
    return PPC(pset) * np.power(np.arange(1, k + 1, dtype=np.float64), 1.0 / n)
450
451
452
@numba.njit(cache=True, fastmath=True)
def _average_columns_except_first(dists, weights):
    """Return the weighted mean of every column of ``dists`` but the
    first, using ``weights[i]`` as the weight of row ``i``.

    ``dists`` is indexed as a 2D (rows, cols) array; the result has
    ``cols - 1`` entries and is divided by the total weight. Compiled
    with numba.
    """
    n_rows, n_cols = dists.shape
    totals = np.zeros((n_cols - 1, ), dtype=np.float64)
    for row in range(n_rows):
        row_weight = weights[row]
        for col in range(1, n_cols):
            totals[col - 1] += dists[row, col] * row_weight
    return totals / np.sum(weights)
460
461
462
def _collapse_into_groups(overlapping: npt.NDArray[np.bool_]) -> list:
    """Greedily group item indices using a pairwise overlap relation.

    ``overlapping`` indicates, for each pair of items in a set, whether
    the two items overlap, in the shape of a condensed distance matrix
    (the pair ordering used by :func:`scipy.spatial.distance.pdist`).

    Items are visited in index order. Each item not yet assigned starts a
    new group, which also absorbs every still-unassigned item overlapping
    it. Every member of a group therefore overlaps the group's first
    (smallest) index, but members need not overlap each other, and the
    relation is not closed transitively.

    Returns
    -------
    list of list of int
        Groups of indices, each in ascending order, with groups ordered
        by their smallest index.
    """

    adjacency = squareform(overlapping)
    labels = {}
    n_groups = 0
    for i, row in enumerate(adjacency):
        if i in labels:
            # Already absorbed by an earlier group; its own neighbours
            # are deliberately not merged in.
            continue
        labels[i] = n_groups
        for j in np.flatnonzero(row):
            labels.setdefault(int(j), n_groups)
        n_groups += 1

    groups = collections.defaultdict(list)
    for index in sorted(labels):
        groups[labels[index]].append(index)
    return list(groups.values())
487