Test Failed
Push — master ( add8ad...560841 )
by Daniel
09:07
created

amd.calculate._get_structure()   A

Complexity

Conditions 4

Size

Total Lines 24
Code Lines 15

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 15
dl 0
loc 24
rs 9.65
c 0
b 0
f 0
cc 4
nop 1
1
"""Functions for calculating the average minimum distance (AMD) and
point-wise distance distribution (PDD) isometric invariants of
periodic crystals and finite sets.
"""
5
6
import collections
7
from typing import Tuple
8
9
import numpy as np
10
import numpy.typing as npt
11
from scipy.spatial.distance import pdist, squareform
12
13
from .periodicset import PeriodicSet, PeriodicSetType
14
from ._nns import nearest_neighbours, nearest_neighbours_minval
15
from .utils import diameter
16
17
18
def AMD(periodic_set: PeriodicSetType, k: int) -> npt.NDArray[np.float64]:
    """The AMD of a periodic set (crystal) up to k.

    Parameters
    ----------
    periodic_set :
        A periodic set represented by a
        :class:`amd.PeriodicSet <.periodicset.PeriodicSet>` or by a
        tuple of :class:`numpy.ndarray` s (motif, cell) with coordinates
        in Cartesian form and a square unit cell.
    k : int
        Length of the AMD returned; the number of neighbours considered
        for each atom in the unit cell to make the AMD.

    Returns
    -------
    :class:`numpy.ndarray`
        A :class:`numpy.ndarray` shape ``(k, )``, the AMD of
        ``periodic_set`` up to k.

    Examples
    --------
    Make list of AMDs with k = 100 for crystals in data.cif::

        amds = []
        for periodic_set in amd.CifReader('data.cif'):
            amds.append(amd.AMD(periodic_set, 100))

    Make list of AMDs with k = 10 for crystals in these CSD refcode families::

        amds = []
        for periodic_set in amd.CSDReader(['HXACAN', 'ACSALA'], families=True):
            amds.append(amd.AMD(periodic_set, 10))

    Manually pass a periodic set as a tuple (motif, cell)::

        # simple cubic lattice
        motif = np.array([[0,0,0]])
        cell = np.array([[1,0,0], [0,1,0], [0,0,1]])
        cubic_amd = amd.AMD((motif, cell), 100)
    """

    motif, cell, asymmetric_unit, weights = _get_structure(periodic_set)
    dists, _, _ = nearest_neighbours(motif, cell, asymmetric_unit, k)
    # Column-wise mean of the neighbour distances, with one weight per row
    # of ``dists`` (uniform, or proportional to Wyckoff multiplicities).
    return np.average(dists, axis=0, weights=weights)
63
64
65
def PDD(
        periodic_set: PeriodicSetType,
        k: int,
        lexsort: bool = True,
        collapse: bool = True,
        collapse_tol: float = 1e-4,
        return_row_groups: bool = False
) -> npt.NDArray[np.float64]:
    """The PDD of a periodic set (crystal) up to k.

    Parameters
    ----------
    periodic_set :
        A periodic set represented by a
        :class:`amd.PeriodicSet <.periodicset.PeriodicSet>` or by a
        tuple of :class:`numpy.ndarray` s (motif, cell) with coordinates
        in Cartesian form and a square unit cell.
    k : int
        The number of neighbours considered for each atom in the unit
        cell. The returned matrix has k + 1 columns, the first column
        for weights of rows.
    lexsort : bool, default True
        Lexicographically order the rows.
    collapse: bool, default True
        Collapse repeated rows (within tolerance ``collapse_tol``).
    collapse_tol: float, default 1e-4
        If two rows have all elements closer than ``collapse_tol``, they
        are merged and weights are given to rows in proportion to the
        number of times they appeared.
    return_row_groups: bool, default False
        Return data about which PDD rows correspond to which points. If
        True, a tuple is returned ``(pdd, groups)`` where ``groups[i]``
        contains the indices of the point(s) corresponding to
        ``pdd[i]``. Note that these indices are for the asymmetric unit
        of the set, whose indices in ``periodic_set.motif`` are
        accessible through ``periodic_set.asymmetric_unit``.

    Returns
    -------
    pdd : :class:`numpy.ndarray`
        A :class:`numpy.ndarray` with k+1 columns, the PDD of
        ``periodic_set`` up to k. The first column contains the weights
        of rows.
    (pdd, groups) : tuple of (:class:`numpy.ndarray`, list)
        Returned instead of ``pdd`` alone when ``return_row_groups`` is
        True.

    Examples
    --------
    Make list of PDDs with ``k=100`` for crystals in data.cif::

        pdds = []
        for periodic_set in amd.CifReader('data.cif'):
            # do not lexicographically order rows
            pdds.append(amd.PDD(periodic_set, 100, lexsort=False))

    Make list of PDDs with ``k=10`` for crystals in these CSD refcode
    families::

        pdds = []
        for periodic_set in amd.CSDReader(['HXACAN', 'ACSALA'], families=True):
            # do not collapse rows
            pdds.append(amd.PDD(periodic_set, 10, collapse=False))

    Manually pass a periodic set as a tuple (motif, cell)::

        # simple cubic lattice
        motif = np.array([[0,0,0]])
        cell = np.array([[1,0,0], [0,1,0], [0,0,1]])
        cubic_amd = amd.PDD((motif, cell), 100)
    """

    motif, cell, asymmetric_unit, weights = _get_structure(periodic_set)
    dists, _, _ = nearest_neighbours(motif, cell, asymmetric_unit, k)
    # Before collapsing, every row is its own group.
    groups = [[i] for i in range(len(dists))]

    if collapse:
        # Two rows overlap when their largest element-wise difference
        # (Chebyshev distance) is within the tolerance.
        overlapping = pdist(dists, metric='chebyshev') <= collapse_tol
        if overlapping.any():
            groups = _collapse_into_groups(overlapping)
            # A merged row carries the summed weight and the mean of
            # the rows it replaces.
            weights = np.array([np.sum(weights[group]) for group in groups])
            dists = np.array([
                np.average(dists[group], axis=0) for group in groups
            ])

    pdd = np.empty(shape=(len(weights), k + 1))
    pdd[:, 0] = weights
    pdd[:, 1:] = dists

    if lexsort:
        # rot90 puts the first distance column last, which np.lexsort
        # treats as the primary key; weights are not part of the ordering.
        lex_ordering = np.lexsort(np.rot90(dists))
        groups = [groups[i] for i in lex_ordering]
        pdd = pdd[lex_ordering]

    if return_row_groups:
        return pdd, groups
    return pdd
161
162
163
def PDD_to_AMD(pdd: npt.NDArray) -> npt.NDArray:
    """Calculates an AMD from a PDD. Faster than computing both from
    scratch.

    Parameters
    ----------
    pdd : :class:`numpy.ndarray`
        The PDD of a periodic set. The first column is read as the row
        weights and the remaining columns as distances.

    Returns
    -------
    :class:`numpy.ndarray`
        The AMD of the periodic set: the weighted average of the
        distance columns of ``pdd``.
    """

    weights = pdd[:, 0]
    distances = pdd[:, 1:]
    return np.average(distances, weights=weights, axis=0)
179
180
181
def AMD_finite(motif: npt.NDArray) -> npt.NDArray:
    """The AMD of a finite m-point set up to k = m - 1.

    Parameters
    ----------
    motif : :class:`numpy.ndarray`
        Coordinates of a set of points.

    Returns
    -------
    :class:`numpy.ndarray`
        A vector shape (motif.shape[0] - 1, ), the AMD of ``motif``.

    Examples
    --------
    The (L-infinity) AMD distance between finite trapezium and kite
    point sets::

        trapezium = np.array([[0,0],[1,1],[3,1],[4,0]])
        kite      = np.array([[0,0],[1,1],[1,-1],[4,0]])

        trap_amd = amd.AMD_finite(trapezium)
        kite_amd = amd.AMD_finite(kite)

        l_inf_dist = np.amax(np.abs(trap_amd - kite_amd))
    """

    # Full pairwise distance matrix with each row sorted ascending; the
    # first column is each point's zero distance to itself, so drop it.
    sorted_dists = np.sort(squareform(pdist(motif)), axis=-1)
    neighbour_dists = sorted_dists[:, 1:]
    return np.average(neighbour_dists, axis=0)
210
211
212
def PDD_finite(
        motif: np.ndarray,
        lexsort: bool = True,
        collapse: bool = True,
        collapse_tol: float = 1e-4,
        return_row_groups: bool = False
) -> npt.NDArray[np.float64]:
    """The PDD of a finite m-point set up to k = m - 1.

    Parameters
    ----------
    motif : :class:`numpy.ndarray`
        Coordinates of a set of points.
    lexsort : bool, default True
        Whether or not to lexicographically order the rows.
    collapse: bool, default True
        Whether or not to collapse repeated rows (within tolerance
        ``collapse_tol``).
    collapse_tol: float, default 1e-4
        If two rows have all elements closer than ``collapse_tol``, they
        are merged and weights are given to rows in proportion to the
        number of times they appeared.
    return_row_groups: bool, default False
        Whether to return data about which PDD rows correspond to which
        points. If True, a tuple is returned ``(pdd, groups)`` where
        ``groups[i]`` contains the indices of the point(s) corresponding
        to ``pdd[i]``.

    Returns
    -------
    pdd : :class:`numpy.ndarray`
        A :class:`numpy.ndarray` with m columns (where m is the number
        of points), the PDD of ``motif``. The first column contains the
        weights of rows.
    (pdd, groups) : tuple of (:class:`numpy.ndarray`, list)
        Returned instead of ``pdd`` alone when ``return_row_groups`` is
        True.

    Examples
    --------
    Find PDD distance between finite trapezium and kite point sets::

        trapezium = np.array([[0,0],[1,1],[3,1],[4,0]])
        kite      = np.array([[0,0],[1,1],[1,-1],[4,0]])

        trap_pdd = amd.PDD_finite(trapezium)
        kite_pdd = amd.PDD_finite(kite)

        dist = amd.EMD(trap_pdd, kite_pdd)
    """

    m = motif.shape[0]
    # Sorted pairwise distances per point; column 0 is the zero
    # self-distance and is dropped, leaving m - 1 columns.
    dists = np.sort(squareform(pdist(motif)), axis=-1)[:, 1:]
    weights = np.full((m, ), 1 / m)
    # Before collapsing, every point is its own group.
    groups = [[i] for i in range(len(dists))]

    if collapse:
        # Two rows overlap when their largest element-wise difference
        # (Chebyshev distance) is within the tolerance.
        overlapping = pdist(dists, metric='chebyshev') <= collapse_tol
        if overlapping.any():
            groups = _collapse_into_groups(overlapping)
            # A merged row carries the summed weight and the mean of
            # the rows it replaces.
            weights = np.array([np.sum(weights[group]) for group in groups])
            dists = np.array([
                np.average(dists[group], axis=0) for group in groups
            ])

    pdd = np.empty(shape=(len(weights), m))
    pdd[:, 0] = weights
    pdd[:, 1:] = dists

    if lexsort:
        # rot90 puts the first distance column last, which np.lexsort
        # treats as the primary key; weights are not part of the ordering.
        lex_ordering = np.lexsort(np.rot90(dists))
        groups = [groups[i] for i in lex_ordering]
        pdd = pdd[lex_ordering]

    if return_row_groups:
        return pdd, groups
    return pdd
287
288
289
def PDD_reconstructable(
        periodic_set: PeriodicSetType,
        lexsort: bool = True
) -> npt.NDArray[np.float64]:
    """The PDD of a periodic set with `k` (number of columns) large
    enough such that the periodic set can be reconstructed from the PDD.

    Parameters
    ----------
    periodic_set :
        A periodic set represented by a
        :class:`amd.PeriodicSet <.periodicset.PeriodicSet>` or by a
        tuple of :class:`numpy.ndarray` s (motif, cell) with coordinates
        in Cartesian form and a square unit cell.
    lexsort : bool, default True
        Whether or not to lexicographically order the rows.

    Returns
    -------
    pdd : :class:`numpy.ndarray`
        The PDD of ``periodic_set`` with enough columns to be
        reconstructable.

    Raises
    ------
    ValueError
        If the unit cell is not 2- or 3-dimensional.
    """

    motif, cell, _, _ = _get_structure(periodic_set)
    dims = cell.shape[0]

    if dims not in (2, 3):
        msg = 'Reconstructing from PDD is only possible for 2 and 3 dimensions'
        raise ValueError(msg)

    # The neighbour search is driven by a threshold of twice the cell
    # diameter rather than by a fixed number of neighbours k.
    min_val = diameter(cell) * 2
    pdd = nearest_neighbours_minval(motif, cell, min_val)

    if lexsort:
        # rot90 makes the first column the primary key for np.lexsort.
        pdd = pdd[np.lexsort(np.rot90(pdd))]

    return pdd
327
328
329
def PPC(periodic_set: PeriodicSetType) -> float:
    r"""The point packing coefficient (PPC) of ``periodic_set``.

    The PPC is a constant of any periodic set determining the
    asymptotic behaviour of its AMD and PDD. As
    :math:`k \rightarrow \infty`, the ratio
    :math:`\text{AMD}_k / \sqrt[n]{k}` converges to the PPC, as does any
    row of its PDD.

    For a unit cell :math:`U` and :math:`m` motif points in :math:`n`
    dimensions,

    .. math::

        \text{PPC} = \sqrt[n]{\frac{\text{Vol}[U]}{m V_n}}

    where :math:`V_n` is the volume of a unit sphere in :math:`n`
    dimensions.

    Parameters
    ----------
    periodic_set :
        A periodic set represented by a
        :class:`amd.PeriodicSet <.periodicset.PeriodicSet>` or by a
        tuple of :class:`numpy.ndarray` s (motif, cell) with coordinates
        in Cartesian form.

    Returns
    -------
    ppc : float
        The PPC of ``periodic_set``.
    """

    # Imported locally: the ``np.math`` alias used previously was
    # deprecated in NumPy 1.25 and removed in NumPy 2.0.
    import math

    motif, cell, _, _ = _get_structure(periodic_set)
    m, n = motif.shape
    # The determinant is signed (negative for a left-handed cell); the
    # volume is its absolute value. Without abs() a negative value would
    # turn the fractional power below into NaN.
    cell_volume = abs(np.linalg.det(cell))
    t = n // 2

    # Closed forms for the volume of the unit n-ball:
    #   n even (n = 2t):     pi^t / t!
    #   n odd  (n = 2t + 1): 2 * t! * (4 pi)^t / n!
    if n % 2 == 0:
        unit_sphere_vol = (np.pi ** t) / math.factorial(t)
    else:
        num = 2 * math.factorial(t) * (4 * np.pi) ** t
        unit_sphere_vol = num / math.factorial(n)

    return (cell_volume / (m * unit_sphere_vol)) ** (1. / n)
374
375
376
def AMD_estimate(
        periodic_set: PeriodicSetType,
        k: int
) -> npt.NDArray[np.float64]:
    r"""Calculates an estimate of AMD based on the PPC.

    Parameters
    ----------
    periodic_set :
        A periodic set represented by a
        :class:`amd.PeriodicSet <.periodicset.PeriodicSet>` or by a
        tuple of :class:`numpy.ndarray` s (motif, cell) with coordinates
        in Cartesian form.
    k : int
        Length of the estimate returned.

    Returns
    -------
    amd_est : :class:`numpy.ndarray`
        An array shape (k, ), where ``amd_est[i]``
        :math:`= \text{PPC} \sqrt[n]{i + 1}` in n dimensions.
    """

    motif, cell, _, _ = _get_structure(periodic_set)
    n = motif.shape[1]
    # Neighbour counts 1..k, each raised to the power 1/n and scaled by
    # the PPC.
    neighbour_counts = np.arange(1, k + 1)
    return PPC((motif, cell)) * np.power(neighbour_counts, 1. / n)
400
401
402
def _get_structure(periodic_set: PeriodicSetType) -> Tuple[npt.NDArray, ...]:
    """``periodic_set`` is either a
    :class:`amd.PeriodicSet <.periodicset.PeriodicSet>` or a tuple of
    :class:`numpy.ndarray` s (motif, cell). If present, extracts the
    asymmetric unit and wyckoff multiplicities from periodic_set.

    Returns a 4-tuple ``(motif, cell, asymmetric_unit, weights)``. When
    the asymmetric unit or the multiplicities are unavailable (or a
    plain tuple was passed), the whole motif is used as the asymmetric
    unit with uniform weights ``1 / len(motif)``.
    """

    if isinstance(periodic_set, PeriodicSet):
        motif = periodic_set.motif
        cell = periodic_set.cell
        asym_unit_inds = periodic_set.asymmetric_unit
        wyc_muls = periodic_set.wyckoff_multiplicities
        if asym_unit_inds is not None and wyc_muls is not None:
            # Symmetry data present: keep only the asymmetric unit and
            # weight each point by its normalised multiplicity.
            asymmetric_unit = motif[asym_unit_inds]
            weights = wyc_muls / np.sum(wyc_muls)
            return motif, cell, asymmetric_unit, weights
    else:
        motif, cell = periodic_set

    # Fallback shared by both input kinds: every motif point counts equally.
    weights = np.full((len(motif), ), 1 / len(motif))
    return motif, cell, motif, weights
426
427
428
def _collapse_into_groups(overlapping: npt.NDArray) -> list:
    """The vector ``overlapping`` indicates for each pair of items in a
    set whether or not the items overlap, in the shape of a condensed
    distance matrix. Returns a list of groups of indices where all items
    in the same group overlap.

    Grouping is greedy: items are visited in index order, and each item
    not yet in a group starts a new one and pulls in every still
    ungrouped item that overlaps it. Overlap is therefore not treated as
    transitive. Groups are ordered by their smallest index and indices
    within a group are ascending.
    """

    # Expand the condensed vector into a full symmetric matrix
    # (the diagonal is False, so an item never lists itself).
    overlap_matrix = squareform(overlapping)
    assigned = set()
    groups = []

    for i, row in enumerate(overlap_matrix):
        if i in assigned:
            continue
        # Every index below i is already assigned, so the new members
        # are all greater than i and arrive in ascending order.
        members = [i]
        for j in np.flatnonzero(row):
            j = int(j)
            if j not in assigned:
                members.append(j)
        assigned.update(members)
        groups.append(members)

    return groups
454