Passed
Push — master ( d05325...baece5 )
by Daniel
03:54
created

amd.calculate.PDD()   B

Complexity

Conditions 6

Size

Total Lines 97
Code Lines 30

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 30
dl 0
loc 97
rs 8.2266
c 0
b 0
f 0
cc 6
nop 6

How to fix   Long Method   

Long Method

Small methods make your code easier to understand, in particular if combined with a good name. Besides, if your method is small, finding a good name is usually much easier.

For example, if you find yourself adding comments to a method's body, this is usually a good sign to extract the commented part to a new method, and use the comment as a starting point when coming up with a good name for this new method.

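Commonly applied refactorings include Extract Method: move a cohesive part of the method body (for example, the code a comment describes) into a new, well-named method.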

1
"""Functions for calculating the average minimum distance (AMD) and
2
point-wise distance distribution (PDD) isometric invariants of
3
periodic crystals and finite sets.
4
"""
5
6
import collections
7
from typing import Tuple, Union
8
9
import numpy as np
10
import numpy.typing as npt
11
from scipy.spatial.distance import pdist, squareform
12
from scipy.special import factorial
13
14
from .periodicset import PeriodicSet
15
from ._nearest_neighbours import nearest_neighbours, nearest_neighbours_minval
16
from .utils import diameter
17
18
__all__ = [
    'AMD',
    'PDD',
    'PDD_to_AMD',
    'AMD_finite',
    'PDD_finite',
    'PDD_reconstructable',
    'PPC',
    'AMD_estimate',
]

# Anything accepted where a periodic set is expected: either a
# PeriodicSet object or a (motif, cell) pair of arrays.
PeriodicSetType = Union[PeriodicSet, Tuple[npt.NDArray, npt.NDArray]]
24
25
26
def AMD(periodic_set: PeriodicSetType, k: int) -> npt.NDArray[np.float64]:
    """Return the average minimum distance (AMD) of a periodic set
    (crystal) up to k.

    Parameters
    ----------
    periodic_set :
        A periodic set given either as an
        :class:`amd.PeriodicSet <.periodicset.PeriodicSet>` or as a
        tuple of :class:`numpy.ndarray` s (motif, cell) holding
        Cartesian coordinates and a square unit cell.
    k : int
        Length of the returned AMD, i.e. how many neighbours of each
        atom in the unit cell contribute to it.

    Returns
    -------
    :class:`numpy.ndarray`
        An array of shape ``(k, )``: the AMD of ``periodic_set`` up to
        k.

    Examples
    --------
    Make list of AMDs with k = 100 for crystals in data.cif::

        amds = []
        for periodic_set in amd.CifReader('data.cif'):
            amds.append(amd.AMD(periodic_set, 100))

    Make list of AMDs with k = 10 for crystals in these CSD refcode families::

        amds = []
        for periodic_set in amd.CSDReader(['HXACAN', 'ACSALA'], families=True):
            amds.append(amd.AMD(periodic_set, 10))

    Manually pass a periodic set as a tuple (motif, cell)::

        # simple cubic lattice
        motif = np.array([[0,0,0]])
        cell = np.array([[1,0,0], [0,1,0], [0,0,1]])
        cubic_amd = amd.AMD((motif, cell), 100)
    """

    motif, cell, asym_unit, point_weights = _get_structure(periodic_set)
    # One extra neighbour is requested because the first column is
    # discarded (presumably each point's zero distance to itself).
    neighbour_dists, _, _ = nearest_neighbours(motif, cell, asym_unit, k + 1)
    # Weighted mean over the (asymmetric unit) points, column by column.
    return np.average(neighbour_dists[:, 1:], axis=0, weights=point_weights)
71
72
73
def PDD(
        periodic_set: PeriodicSetType,
        k: int,
        lexsort: bool = True,
        collapse: bool = True,
        collapse_tol: float = 1e-4,
        return_row_groups: bool = False
) -> Union[npt.NDArray[np.float64], Tuple[npt.NDArray[np.float64], list]]:
    """Return the PDD of a periodic set (crystal) up to k.

    Parameters
    ----------
    periodic_set :
        A periodic set represented by a
        :class:`amd.PeriodicSet <.periodicset.PeriodicSet>` or by a
        tuple of :class:`numpy.ndarray` s (motif, cell) with Cartesian
        coordinates and a square matrix representing the unit cell.
    k : int
        The number of neighbours considered for each atom (point) in the
        unit cell. The returned matrix has k + 1 columns, the first
        column for weights of rows.
    lexsort : bool, default True
        Lexicographically order the rows.
    collapse: bool, default True
        Collapse repeated rows (within tolerance ``collapse_tol``).
    collapse_tol: float, default 1e-4
        If two rows have all elements closer than ``collapse_tol``, they
        are merged and weights are given to rows in proportion to the
        number of times they appeared.
    return_row_groups: bool, default False
        Return data about which PDD rows correspond to which points. If
        True, a tuple is returned ``(pdd, groups)`` where ``groups[i]``
        contains the indices of the point(s) corresponding to
        ``pdd[i]``. Note that these indices are for the asymmetric unit
        of the set, whose indices in ``periodic_set.motif`` are
        accessible through ``periodic_set.asymmetric_unit``.

    Returns
    -------
    pdd : :class:`numpy.ndarray`
        A :class:`numpy.ndarray` with k+1 columns, the PDD of
        ``periodic_set`` up to k. The first column contains the weights
        of rows. If ``return_row_groups`` is True, returns a tuple with
        types (:class:`numpy.ndarray`, list).

    Examples
    --------
    Make list of PDDs with ``k=100`` for crystals in data.cif::

        pdds = []
        for periodic_set in amd.CifReader('data.cif'):
            # do not lexicographically order rows
            pdds.append(amd.PDD(periodic_set, 100, lexsort=False))

    Make list of PDDs with ``k=10`` for crystals in these CSD refcode
    families::

        pdds = []
        for periodic_set in amd.CSDReader(['HXACAN', 'ACSALA'], families=True):
            # do not collapse rows
            pdds.append(amd.PDD(periodic_set, 10, collapse=False))

    Manually pass a periodic set as a tuple (motif, cell)::

        # simple cubic lattice
        motif = np.array([[0,0,0]])
        cell = np.array([[1,0,0], [0,1,0], [0,0,1]])
        cubic_pdd = amd.PDD((motif, cell), 100)
    """

    motif, cell, asymmetric_unit, weights = _get_structure(periodic_set)
    dists, _, _ = nearest_neighbours(motif, cell, asymmetric_unit, k + 1)
    # k + 1 neighbours were requested; the first column is discarded.
    dists = dists[:, 1:]

    if collapse:
        dists, weights, groups = _merge_overlapping_rows(
            dists, weights, collapse_tol
        )
    else:
        groups = [[i] for i in range(len(dists))]

    pdd = np.empty(shape=(len(weights), k + 1), dtype=np.float64)
    pdd[:, 0] = weights
    pdd[:, 1:] = dists

    if lexsort:
        lex_ordering = _lexicographic_order(dists)
        pdd = pdd[lex_ordering]
        if return_row_groups:
            groups = [groups[i] for i in lex_ordering]

    if return_row_groups:
        return pdd, groups
    return pdd


def _merge_overlapping_rows(
        dists: npt.NDArray,
        weights: npt.NDArray,
        collapse_tol: float
) -> Tuple[npt.NDArray, npt.NDArray, list]:
    """Merge rows of ``dists`` whose Chebyshev distance is at most
    ``collapse_tol``.

    Returns ``(dists, weights, groups)``: the merged rows (element-wise
    mean of each group), the summed weight of each group, and for each
    merged row the list of original row indices it came from. If no
    rows overlap the inputs are returned unchanged with singleton
    groups.
    """

    groups = [[i] for i in range(len(dists))]
    # With fewer than two rows there is nothing to compare.
    if len(dists) < 2:
        return dists, weights, groups

    overlapping = pdist(dists, metric='chebyshev') <= collapse_tol
    if not overlapping.any():
        return dists, weights, groups

    groups = _collapse_into_groups(overlapping)
    merged_weights = np.array([np.sum(weights[group]) for group in groups])
    merged_dists = np.array([
        np.average(dists[group], axis=0) for group in groups
    ], dtype=np.float64)
    return merged_dists, merged_weights, groups


def _lexicographic_order(rows: npt.NDArray) -> npt.NDArray:
    """Return indices sorting the rows of the 2D array ``rows``
    lexicographically, with the first column as the primary key.
    """

    if rows.shape[1] == 0:
        # np.lexsort needs at least one key; with no columns all rows
        # compare equal, so keep the current order.
        return np.arange(rows.shape[0])
    # np.lexsort treats its last key as primary; np.rot90 lists the
    # columns last-to-first, making the first column the primary key.
    return np.lexsort(np.rot90(rows))
170
171
172
def PDD_to_AMD(pdd: npt.NDArray) -> npt.NDArray:
    """Derive the AMD from an already computed PDD, which is cheaper
    than calculating both from scratch.

    The AMD is the weighted mean of the PDD's distance columns, using
    the PDD's first column as the row weights.

    Parameters
    ----------
    pdd : :class:`numpy.ndarray`
        The PDD of a periodic set (first column holds row weights).

    Returns
    -------
    :class:`numpy.ndarray`
        The AMD of the periodic set.
    """

    row_weights = pdd[:, 0]
    distance_cols = pdd[:, 1:]
    return np.average(distance_cols, axis=0, weights=row_weights)
188
189
190
def AMD_finite(motif: npt.NDArray) -> npt.NDArray:
    """Return the AMD of a finite set of m points, up to k = m - 1.

    Parameters
    ----------
    motif : :class:`numpy.ndarray`
        Coordinates of a set of points.

    Returns
    -------
    :class:`numpy.ndarray`
        A vector of shape (motif.shape[0] - 1, ): the AMD of ``motif``.

    Examples
    --------
    The (L-infinity) AMD distance between finite trapezium and kite
    point sets::

        trapezium = np.array([[0,0],[1,1],[3,1],[4,0]])
        kite      = np.array([[0,0],[1,1],[1,-1],[4,0]])

        trap_amd = amd.AMD_finite(trapezium)
        kite_amd = amd.AMD_finite(kite)

        l_inf_dist = np.amax(np.abs(trap_amd - kite_amd))
    """

    pairwise = squareform(pdist(motif))
    pairwise.sort(axis=-1)
    # After sorting each row, column 0 is the point's zero distance to
    # itself, so it is left out of the average.
    return pairwise[:, 1:].mean(axis=0)
219
220
221
def PDD_finite(
        motif: np.ndarray,
        lexsort: bool = True,
        collapse: bool = True,
        collapse_tol: float = 1e-4,
        return_row_groups: bool = False
) -> Union[npt.NDArray[np.float64], Tuple[npt.NDArray[np.float64], list]]:
    """Return the PDD of a finite m-point set up to k = m - 1.

    Parameters
    ----------
    motif : :class:`numpy.ndarray`
        Coordinates of a set of points.
    lexsort : bool, default True
        Whether or not to lexicographically order the rows.
    collapse: bool, default True
        Whether or not to collapse repeated rows (within tolerance
        ``collapse_tol``).
    collapse_tol: float, default 1e-4
        If two rows have all elements closer than ``collapse_tol``, they
        are merged and weights are given to rows in proportion to the
        number of times they appeared.
    return_row_groups: bool, default False
        Whether to return data about which PDD rows correspond to which
        points. If True, a tuple is returned ``(pdd, groups)`` where
        ``groups[i]`` contains the indices of the point(s) corresponding
        to ``pdd[i]``.

    Returns
    -------
    pdd : :class:`numpy.ndarray`
        A :class:`numpy.ndarray` with m columns (where m is the number
        of points), the PDD of ``motif``. The first column contains the
        weights of rows. For a single point the result is ``[[1.0]]``.

    Examples
    --------
    Find PDD distance between finite trapezium and kite point sets::

        trapezium = np.array([[0,0],[1,1],[3,1],[4,0]])
        kite      = np.array([[0,0],[1,1],[1,-1],[4,0]])

        trap_pdd = amd.PDD_finite(trapezium)
        kite_pdd = amd.PDD_finite(kite)

        dist = amd.EMD(trap_pdd, kite_pdd)
    """

    m = motif.shape[0]
    # Row i holds the distances from point i to all points, ascending;
    # the first entry (the zero self-distance) is dropped.
    dists = np.sort(squareform(pdist(motif)), axis=-1)[:, 1:]
    weights = np.full((m, ), 1 / m)
    groups = [[i] for i in range(len(dists))]

    # Collapsing needs at least two rows to compare.
    if collapse and len(dists) > 1:
        overlapping = pdist(dists, metric='chebyshev') <= collapse_tol
        if overlapping.any():
            groups = _collapse_into_groups(overlapping)
            weights = np.array([np.sum(weights[group]) for group in groups])
            dists = np.array([
                np.average(dists[group], axis=0) for group in groups
            ], dtype=np.float64)

    pdd = np.empty(shape=(len(weights), m), dtype=np.float64)
    pdd[:, 0] = weights
    pdd[:, 1:] = dists

    # np.lexsort raises when given zero keys, which happens for a single
    # point (no distance columns); there is nothing to order then.
    if lexsort and dists.shape[1] > 0:
        # np.lexsort uses its last key as primary; np.rot90 lists the
        # columns last-to-first so the first column is the primary key.
        lex_ordering = np.lexsort(np.rot90(dists))
        pdd = pdd[lex_ordering]
        if return_row_groups:
            groups = [groups[i] for i in lex_ordering]

    if return_row_groups:
        return pdd, groups
    return pdd
296
297
298
def PDD_reconstructable(
        periodic_set: PeriodicSetType,
        lexsort: bool = True
) -> npt.NDArray[np.float64]:
    """Return the PDD of a periodic set with enough columns that the
    periodic set can be reconstructed from it.

    Parameters
    ----------
    periodic_set :
        A periodic set given either as an
        :class:`amd.PeriodicSet <.periodicset.PeriodicSet>` or as a
        tuple of :class:`numpy.ndarray` s (motif, cell) holding
        Cartesian coordinates and a square unit cell.
    lexsort : bool, default True
        Whether or not to lexicographically order the rows.

    Returns
    -------
    pdd : :class:`numpy.ndarray`
        The PDD of ``periodic_set`` with enough columns to be
        reconstructable.

    Raises
    ------
    ValueError
        If the cell is not 2 or 3 dimensional.
    """

    motif, cell, _, _ = _get_structure(periodic_set)

    if cell.shape[0] not in (2, 3):
        raise ValueError(
            'Reconstructing from PDD is only possible for 2 and 3 dimensions.'
        )

    # The threshold handed to nearest_neighbours_minval is twice the
    # cell's diameter.
    pdd, _, _ = nearest_neighbours_minval(motif, cell, diameter(cell) * 2)

    if not lexsort:
        return pdd
    # np.rot90 presents the columns last-to-first, so np.lexsort (whose
    # last key is primary) sorts by the first column first.
    return pdd[np.lexsort(np.rot90(pdd))]
339
340
341
def PPC(periodic_set: PeriodicSetType) -> float:
    r"""Return the point packing coefficient (PPC) of ``periodic_set``.

    The PPC is a constant of any periodic set determining the
    asymptotic behaviour of its AMD and PDD. As
    :math:`k \rightarrow \infty`, the ratio
    :math:`\text{AMD}_k / \sqrt[n]{k}` converges to the PPC, as does any
    row of its PDD.

    For a unit cell :math:`U` and :math:`m` motif points in :math:`n`
    dimensions,

    .. math::

        \text{PPC} = \sqrt[n]{\frac{\text{Vol}[U]}{m V_n}}

    where :math:`V_n` is the volume of a unit sphere in :math:`n`
    dimensions and :math:`\text{Vol}[U] = |\det U|`, so left- and
    right-handed cell bases give the same result.

    Parameters
    ----------
    periodic_set :
        A periodic set represented by a
        :class:`amd.PeriodicSet <.periodicset.PeriodicSet>` or by a
        tuple of :class:`numpy.ndarray` s (motif, cell) with coordinates
        in Cartesian form.

    Returns
    -------
    ppc : float
        The PPC of ``periodic_set``.
    """

    motif, cell, _, _ = _get_structure(periodic_set)
    m, n = motif.shape
    t = int(n // 2)

    # Volume of the unit n-ball: pi^t / t! for even n = 2t, and
    # 2 t! (4 pi)^t / n! for odd n = 2t + 1.
    if n % 2 == 0:
        unit_sphere_vol = (np.pi ** t) / factorial(t)
    else:
        unit_sphere_vol = (2 * factorial(t) * (4 * np.pi) ** t) / factorial(n)

    # The determinant is negative for a left-handed basis; without the
    # absolute value the n-th root below would produce NaN.
    cell_vol = np.abs(np.linalg.det(cell))
    return (cell_vol / (m * unit_sphere_vol)) ** (1.0 / n)
384
385
386
def AMD_estimate(
        periodic_set: PeriodicSetType,
        k: int
) -> npt.NDArray[np.float64]:
    r"""Calculate an estimate of AMD based on the PPC.

    Parameters
    ----------
    periodic_set :
        A periodic set represented by a
        :class:`amd.PeriodicSet <.periodicset.PeriodicSet>` or by a
        tuple of :class:`numpy.ndarray` s (motif, cell) with coordinates
        in Cartesian form.
    k : int
        Length of the returned estimate (number of neighbours).

    Returns
    -------
    amd_est : :class:`numpy.ndarray`
        An array shape (k, ), where ``amd_est[i]``
        :math:`= \text{PPC} \sqrt[n]{i + 1}` in n dimensions, i.e. the
        estimate for the (i + 1)-th neighbour.
    """

    motif, cell, _, _ = _get_structure(periodic_set)
    n = cell.shape[0]
    # n-th roots of 1, 2, ..., k.
    k_root = np.power(np.arange(1, k + 1, dtype=np.float64), 1.0 / n)
    return PPC((motif, cell)) * k_root
411
412
413
def _get_structure(periodic_set: PeriodicSetType) -> Tuple[npt.NDArray, ...]:
    """Unpack ``periodic_set`` into ``(motif, cell, asymmetric_unit,
    weights)``.

    ``periodic_set`` is either a
    :class:`amd.PeriodicSet <.periodicset.PeriodicSet>` or a tuple of
    :class:`numpy.ndarray` s (motif, cell). When a PeriodicSet carries
    both asymmetric unit indices and Wyckoff multiplicities, the
    asymmetric unit points are returned with the multiplicities
    normalised to sum to 1 as weights. Otherwise the whole motif is
    returned as the asymmetric unit with equal weights.
    """

    if isinstance(periodic_set, PeriodicSet):
        motif, cell = periodic_set.motif, periodic_set.cell
        asym_inds = periodic_set.asymmetric_unit
        multiplicities = periodic_set.wyckoff_multiplicities
        if asym_inds is not None and multiplicities is not None:
            asymmetric_unit = motif[asym_inds]
            weights = multiplicities / np.sum(multiplicities)
            return motif, cell, asymmetric_unit, weights
    else:
        motif, cell = periodic_set

    # No symmetry information: every motif point represents itself and
    # all points weigh the same.
    n_points = len(motif)
    uniform = np.full((n_points, ), 1 / n_points, dtype=np.float64)
    return motif, cell, motif, uniform
438
439
440
def _collapse_into_groups(overlapping: npt.NDArray) -> list:
441
    """Return a list of groups of indices where all indices in the same
442
    group overlap. ``overlapping`` indicates for each pair of items in a
443
    set whether or not the items overlap, in the shape of a condensed
444
    distance matrix.
445
    """
446
447
    overlapping = squareform(overlapping)
448
    group_nums = {}
449
    group = 0
450
    for i, row in enumerate(overlapping):
451
        if i not in group_nums:
452
            group_nums[i] = group
453
            group += 1
454
455
            for j in np.argwhere(row).T[0]:
456
                if j not in group_nums:
457
                    group_nums[j] = group_nums[i]
458
459
    groups = collections.defaultdict(list)
460
    for row_ind, group_num in sorted(group_nums.items()):
461
        groups[group_num].append(row_ind)
462
    groups = list(groups.values())
463
464
    return groups
465