Passed
Push — master ( 8c12c2...4daa36 )
by Daniel
07:46
created

amd.calculate   A

Complexity

Total Complexity 34

Size/Duplication

Total Lines 466
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
wmc 34
eloc 150
dl 0
loc 466
rs 9.68
c 0
b 0
f 0

10 Functions

Rating   Name   Duplication   Size   Complexity  
A PDD_to_AMD() 0 15 1
A PDD_reconstructable() 0 31 3
B PDD_finite() 0 76 6
A PPC() 0 39 2
B _collapse_into_groups() 0 25 6
A AMD_estimate() 0 19 1
A _average_columns_except_first() 0 13 3
B PDD() 0 114 8
A AMD() 0 57 3
A AMD_finite() 0 29 1
"""Functions for calculating the average minimum distance (AMD) and
point-wise distance distribution (PDD) isometric invariants of
periodic crystals and finite sets.
"""

import collections
7
from typing import Tuple, Union
8
9
import numpy as np
10
import numba
11
from scipy.spatial.distance import pdist, squareform
12
from scipy.special import factorial
13
14
from .periodicset import PeriodicSet
15
from ._nearest_neighbours import nearest_neighbours, nearest_neighbours_minval
16
from .utils import diameter
17
18
# Public API of this module; private helpers (leading underscore) are
# deliberately left out.
__all__ = [
    'AMD',
    'PDD',
    'PDD_to_AMD',
    'AMD_finite',
    'PDD_finite',
    'PDD_reconstructable',
    'PPC',
    'AMD_estimate',
]
28
29
30
def AMD(pset: PeriodicSet, k: int) -> np.ndarray:
    """Return the average minimum distance (AMD) of a periodic set
    (crystal).

    The AMD of a periodic set is a geometry based descriptor independent
    of choice of motif and unit cell. It is a vector, the (weighted)
    average of the :func:`PDD <.calculate.PDD>` matrix, which has one
    row for each (unique) point in the unit cell containing distances to
    k nearest neighbours.

    Parameters
    ----------
    pset : :class:`amd.PeriodicSet <.periodicset.PeriodicSet>`
        A periodic set (crystal) consisting of a unit cell and motif of
        points.
    k : int
        The number of neighbouring points (atoms) considered for each
        point in the unit cell.

    Returns
    -------
    :class:`numpy.ndarray`
        A :class:`numpy.ndarray` shape ``(k, )``, the AMD of
        ``pset`` up to k.

    Examples
    --------
    Make list of AMDs with k = 100 for crystals in data.cif::

        amds = []
        for periodic_set in amd.CifReader('data.cif'):
            amds.append(amd.AMD(periodic_set, 100))

    Make list of AMDs with k = 10 for crystals in these CSD refcode families::

        amds = []
        for periodic_set in amd.CSDReader(['HXACAN', 'ACSALA'], families=True):
            amds.append(amd.AMD(periodic_set, 10))

    Manually create a periodic set from a motif and a unit cell::

        # simple cubic lattice
        motif = np.array([[0,0,0]])
        cell = np.array([[1,0,0], [0,1,0], [0,0,1]])
        periodic_set = amd.PeriodicSet(motif, cell)
        cubic_amd = amd.AMD(periodic_set, 100)
    """

    if pset.asym_unit is None or pset.multiplicities is None:
        # No symmetry information: every motif point gets equal weight.
        asym_unit = pset.motif
        weights = np.ones((asym_unit.shape[0], ), dtype=np.float64)
    else:
        # Only the asymmetric unit is queried; each of its points is
        # weighted by its multiplicity.
        asym_unit = pset.motif[pset.asym_unit]
        weights = pset.multiplicities

    # k + 1 neighbours are requested because the first column is dropped
    # by the averaging helper below.
    dists = nearest_neighbours(pset.motif, pset.cell, asym_unit, k + 1)
    return _average_columns_except_first(dists, weights)
87
88
89
def PDD(
        pset: PeriodicSet,
        k: int,
        lexsort: bool = True,
        collapse: bool = True,
        collapse_tol: float = 1e-4,
        return_row_groups: bool = False
) -> Union[np.ndarray, Tuple[np.ndarray, list]]:
    """Return the point-wise distance distribution (PDD) of a periodic
    set (crystal).

    The PDD of a periodic set is a geometry based descriptor independent
    of choice of motif and unit cell. It is a matrix where each row
    corresponds to a point in the motif, containing a weight followed by
    distances to the k nearest neighbours of the point.

    Parameters
    ----------
    pset : :class:`amd.PeriodicSet <.periodicset.PeriodicSet>`
        A periodic set (crystal) consisting of a unit cell and motif of
        points.
    k : int
        The number of neighbours considered for each atom (point) in the
        unit cell. The returned matrix has k + 1 columns, the first
        column for weights of rows.
    lexsort : bool, default True
        Lexicographically order the rows.
    collapse: bool, default True
        Collapse repeated rows (within tolerance ``collapse_tol``).
    collapse_tol: float, default 1e-4
        If two rows have all elements closer than ``collapse_tol``, they
        are merged and weights are given to rows in proportion to the
        number of times they appeared.
    return_row_groups: bool, default False
        If True, return a tuple ``(pdd, groups)`` where ``groups``
        contains information about which rows in ``pdd`` correspond to
        which points. If ``pset.asym_unit`` is None, then
        ``groups[i]`` contains indices of points in
        ``pset.motif`` corresponding to ``pdd[i]``. Otherwise,
        PDD rows correspond to points in the asymmetric unit, and
        ``groups[i]`` contains indices of points in
        ``pset.motif[pset.asym_unit]``.

    Returns
    -------
    pdd : :class:`numpy.ndarray`
        A :class:`numpy.ndarray` with k+1 columns, the PDD of
        ``pset`` up to k. The first column contains the weights
        of rows. If ``return_row_groups`` is True, returns a tuple with
        types (:class:`numpy.ndarray`, list).

    Examples
    --------
    Make list of PDDs with ``k=100`` for crystals in data.cif::

        pdds = []
        for periodic_set in amd.CifReader('data.cif'):
            # do not lexicographically order rows
            pdds.append(amd.PDD(periodic_set, 100, lexsort=False))

    Make list of PDDs with ``k=10`` for crystals in these CSD refcode
    families::

        pdds = []
        for periodic_set in amd.CSDReader(['HXACAN', 'ACSALA'], families=True):
            # do not collapse rows
            pdds.append(amd.PDD(periodic_set, 10, collapse=False))

    Manually create a periodic set from a motif and a unit cell::

        # simple cubic lattice
        motif = np.array([[0,0,0]])
        cell = np.array([[1,0,0], [0,1,0], [0,0,1]])
        periodic_set = amd.PeriodicSet(motif, cell)
        cubic_pdd = amd.PDD(periodic_set, 100)
    """

    m = pset.motif.shape[0]
    if pset.asym_unit is None or pset.multiplicities is None:
        # No symmetry information: one row per motif point, equal weights.
        asym_unit = pset.motif
        weights = np.full((m, ), 1 / m, dtype=np.float64)
    else:
        # One row per asymmetric-unit point, weighted by multiplicity.
        asym_unit = pset.motif[pset.asym_unit]
        weights = pset.multiplicities / m

    # Request k + 1 neighbours and discard the first column.
    dists = nearest_neighbours(pset.motif, pset.cell, asym_unit, k + 1)
    dists = dists[:, 1:]
    groups = [[i] for i in range(len(dists))]

    if collapse:
        # Chebyshev distance <= tol means every entry of the two rows
        # differs by at most ``collapse_tol``.
        overlapping = pdist(dists, metric='chebyshev') <= collapse_tol
        if overlapping.any():
            groups = _collapse_into_groups(overlapping)
            # Merged rows carry the summed weight and the mean distances
            # of their members.
            weights = np.array([np.sum(weights[group]) for group in groups])
            dists = np.array(
                [np.average(dists[group], axis=0) for group in groups],
                dtype=np.float64
            )

    pdd = np.empty(shape=(len(dists), k + 1), dtype=np.float64)

    if lexsort:
        # rot90 turns columns into lexsort keys with the first distance
        # column as the primary key.
        lex_ordering = np.lexsort(np.rot90(dists))
        pdd[:, 0] = weights[lex_ordering]
        pdd[:, 1:] = dists[lex_ordering]
        if return_row_groups:
            groups = [groups[i] for i in lex_ordering]
    else:
        pdd[:, 0] = weights
        pdd[:, 1:] = dists

    if return_row_groups:
        return pdd, groups
    return pdd
203
204
205
def PDD_to_AMD(pdd: np.ndarray) -> np.ndarray:
    """Calculate an AMD from a PDD. Faster than computing both from
    scratch.

    Parameters
    ----------
    pdd : :class:`numpy.ndarray`
        The PDD of a periodic set. The first column holds the row
        weights, the remaining columns hold distances. Any 2D
        array-like is accepted.

    Returns
    -------
    :class:`numpy.ndarray`
        The AMD of the periodic set: the weighted average of the
        distance columns, using the first column as weights.
    """
    pdd = np.asarray(pdd)
    return np.average(pdd[:, 1:], weights=pdd[:, 0], axis=0)
220
221
222
def AMD_finite(motif: np.ndarray) -> np.ndarray:
    """Return the AMD of a finite m-point set up to k = m - 1.

    Parameters
    ----------
    motif : :class:`numpy.ndarray`
        Collection of points.

    Returns
    -------
    :class:`numpy.ndarray`
        A vector shape (motif.shape[0] - 1, ), the AMD of ``motif``.

    Examples
    --------
    The (L-infinity) AMD distance between finite trapezium and kite
    point sets::

        trapezium = np.array([[0,0],[1,1],[3,1],[4,0]])
        kite      = np.array([[0,0],[1,1],[1,-1],[4,0]])

        trap_amd = amd.AMD_finite(trapezium)
        kite_amd = amd.AMD_finite(kite)

        l_inf_dist = np.amax(np.abs(trap_amd - kite_amd))
    """

    # Full pairwise distance matrix, each row sorted ascending.
    sorted_dists = np.sort(squareform(pdist(motif)), axis=-1)
    # Column 0 is each point's zero distance to itself; drop it.
    neighbour_dists = sorted_dists[:, 1:]
    return np.average(neighbour_dists, axis=0)
251
252
253
def PDD_finite(
        motif: np.ndarray,
        lexsort: bool = True,
        collapse: bool = True,
        collapse_tol: float = 1e-4,
        return_row_groups: bool = False
) -> Union[np.ndarray, Tuple[np.ndarray, list]]:
    """Return the PDD of a finite m-point set up to k = m - 1.

    Parameters
    ----------
    motif : :class:`numpy.ndarray`
        Coordinates of a set of points.
    lexsort : bool, default True
        Whether or not to lexicographically order the rows.
    collapse: bool, default True
        Whether or not to collapse repeated rows (within tolerance
        ``collapse_tol``).
    collapse_tol: float, default 1e-4
        If two rows have all elements closer than ``collapse_tol``, they
        are merged and weights are given to rows in proportion to the
        number of times they appeared.
    return_row_groups: bool, default False
        If True, return a tuple ``(pdd, groups)`` where ``groups[i]``
        contains indices of points in ``motif`` corresponding to
        ``pdd[i]``.

    Returns
    -------
    pdd : :class:`numpy.ndarray`
        A :class:`numpy.ndarray` with m columns (where m is the number
        of points), the PDD of ``motif``. The first column contains the
        weights of rows. If ``return_row_groups`` is True, a tuple
        ``(pdd, groups)`` is returned instead.

    Examples
    --------
    Find PDD distance between finite trapezium and kite point sets::

        trapezium = np.array([[0,0],[1,1],[3,1],[4,0]])
        kite      = np.array([[0,0],[1,1],[1,-1],[4,0]])

        trap_pdd = amd.PDD_finite(trapezium)
        kite_pdd = amd.PDD_finite(kite)

        dist = amd.EMD(trap_pdd, kite_pdd)
    """

    m = motif.shape[0]
    # Sort each row of the pairwise distance matrix and drop column 0,
    # the zero distance from each point to itself.
    dists = np.sort(squareform(pdist(motif)), axis=-1)[:, 1:]
    weights = np.full((m, ), 1 / m)
    groups = [[i] for i in range(len(dists))]

    if collapse:
        # Chebyshev distance <= tol means every entry of the two rows
        # differs by at most ``collapse_tol``.
        overlapping = pdist(dists, metric='chebyshev') <= collapse_tol
        if overlapping.any():
            groups = _collapse_into_groups(overlapping)
            # Merged rows carry the summed weight and the mean distances
            # of their members.
            weights = np.array([np.sum(weights[group]) for group in groups])
            dists = np.array([
                np.average(dists[group], axis=0) for group in groups
            ], dtype=np.float64)

    pdd = np.empty(shape=(len(weights), m), dtype=np.float64)

    if lexsort:
        # rot90 turns columns into lexsort keys with the first distance
        # column as the primary key.
        lex_ordering = np.lexsort(np.rot90(dists))
        pdd[:, 0] = weights[lex_ordering]
        pdd[:, 1:] = dists[lex_ordering]
        if return_row_groups:
            groups = [groups[i] for i in lex_ordering]
    else:
        pdd[:, 0] = weights
        pdd[:, 1:] = dists

    if return_row_groups:
        return pdd, groups
    return pdd
329
330
331
def PDD_reconstructable(pset: PeriodicSet, lexsort: bool = True) -> np.ndarray:
    """Return the PDD of a periodic set with `k` (number of columns)
    large enough such that the periodic set can be reconstructed from
    the PDD.

    Parameters
    ----------
    pset : :class:`amd.PeriodicSet <.periodicset.PeriodicSet>`
        A periodic set (crystal) consisting of a unit cell and motif of
        points.
    lexsort : bool, default True
        Whether or not to lexicographically order the rows.

    Returns
    -------
    pdd : :class:`numpy.ndarray`
        The PDD of ``pset`` with enough columns to reconstruct ``pset``
        using :func:`amd.reconstruct.reconstruct`.

    Raises
    ------
    ValueError
        If the unit cell is not 2- or 3-dimensional.
    """

    dims = pset.cell.shape[0]
    if dims not in (2, 3):
        raise ValueError(
            'Reconstructing from PDD is only possible for 2 and 3 dimensions.'
        )
    # Threshold passed to the neighbour search: twice the cell diameter.
    min_val = diameter(pset.cell) * 2
    # Only the first of the three returned values is needed here.
    pdd, _, _ = nearest_neighbours_minval(pset.motif, pset.cell, min_val)
    if lexsort:
        lex_ordering = np.lexsort(np.rot90(pdd))
        pdd = pdd[lex_ordering]
    return pdd
362
363
364
def PPC(pset: PeriodicSet) -> float:
    r"""Return the point packing coefficient (PPC) of ``pset``.

    The PPC is a constant of any periodic set determining the
    asymptotic behaviour of its AMD and PDD. As
    :math:`k \rightarrow \infty`, the ratio
    :math:`\text{AMD}_k / \sqrt[n]{k}` converges to the PPC, as does any
    row of its PDD.

    For a unit cell :math:`U` and :math:`m` motif points in :math:`n`
    dimensions,

    .. math::

        \text{PPC} = \sqrt[n]{\frac{\text{Vol}[U]}{m V_n}}

    where :math:`V_n` is the volume of a unit sphere in :math:`n`
    dimensions.

    Parameters
    ----------
    pset : :class:`amd.PeriodicSet <.periodicset.PeriodicSet>`
        A periodic set (crystal) consisting of a unit cell and motif of
        points.

    Returns
    -------
    ppc : float
        The PPC of ``pset``.
    """

    m, n = pset.motif.shape
    t = n // 2
    # Closed forms for the volume of the unit n-ball, split by parity
    # of n so only integer factorials are needed.
    if n % 2 == 0:
        sphere_vol = (np.pi ** t) / factorial(t)
    else:
        sphere_vol = (2 * factorial(t) * (4 * np.pi) ** t) / factorial(n)

    # The determinant is signed: a left-handed cell gives a negative
    # value, and a negative base under a fractional power yields NaN.
    # The cell volume is the absolute value.
    cell_volume = abs(np.linalg.det(pset.cell))
    return float((cell_volume / (m * sphere_vol)) ** (1.0 / n))
403
404
405
def AMD_estimate(pset: PeriodicSet, k: int) -> np.ndarray:
    r"""Calculate an estimate of AMD based on the PPC.

    Parameters
    ----------
    pset : :class:`amd.PeriodicSet <.periodicset.PeriodicSet>`
        A periodic set (crystal) consisting of a unit cell and motif of
        points.
    k : int
        Length of the estimate to return.

    Returns
    -------
    amd_est : :class:`numpy.ndarray`
        An array shape (k, ), where ``amd_est[i]``
        :math:`= \text{PPC} \sqrt[n]{i + 1}` in n dimensions.
    """

    n = pset.cell.shape[0]
    k_root = np.power(np.arange(1, k + 1, dtype=np.float64), 1.0 / n)
    # PPC reads ``pset.motif`` and ``pset.cell`` itself, so it must be
    # given the periodic set, not a (motif, cell) tuple.
    return PPC(pset) * k_root
424
425
426
@numba.njit(cache=True, fastmath=True)
def _average_columns_except_first(dists, weights):
    """Return the weighted average of each column of ``dists`` except
    the first.

    ``dists`` is a 2D array with one row per point and ``weights`` has
    one entry per row. The result has ``dists.shape[1] - 1`` entries,
    each a column's weighted sum divided by the sum of all weights.
    """
    m, k = dists.shape
    k -= 1
    result = np.empty((k, ), dtype=np.float64)
    div = np.sum(weights)
    for j in range(k):
        # Float accumulator so its type does not change inside the loop.
        av = 0.0
        for i in range(m):
            av += dists[i, j + 1] * weights[i]
        result[j] = av
    result /= div
    return result
439
440
441
def _collapse_into_groups(overlapping: np.ndarray) -> list:
    """Return a list of groups of indices where all indices in the same
    group overlap. ``overlapping`` indicates for each pair of items in a
    set whether or not the items overlap, in the shape of a condensed
    distance matrix.

    Grouping is greedy: items are visited in index order, and each
    unassigned item starts a new group that takes every still-unassigned
    item overlapping it. Groups are returned in order of their smallest
    index, and indices within a group are ascending.
    """

    square = squareform(overlapping)
    assigned = set()
    groups = []
    for i, row in enumerate(square):
        if i in assigned:
            continue
        # ``i`` is the smallest unassigned index, so it seeds a group.
        assigned.add(i)
        members = [i]
        for j in np.flatnonzero(row):
            j = int(j)
            if j not in assigned:
                assigned.add(j)
                members.append(j)
        groups.append(members)

    return groups
466