Passed
Push — master ( e06329...c0aa4a )
by Daniel
03:48
created

amd.calculate.PDF()   A

Complexity

Conditions 3

Size

Total Lines 35
Code Lines 15

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 15
dl 0
loc 35
rs 9.65
c 0
b 0
f 0
cc 3
nop 2
1
"""Functions for calculating AMDs and PDDs (and SDDs) of periodic and finite sets.
2
"""
3
4
from typing import Union, Tuple
5
import itertools
6
import collections
7
8
import numpy as np
9
import scipy.spatial
10
import scipy.special
11
12
from ._nearest_neighbours import nearest_neighbours, nearest_neighbours_minval, generate_concentric_cloud
0 ignored issues
show
Coding Style introduced by
This line is too long as per the coding-style (105/100).

This check looks for lines that are too long. You can specify the maximum line length.

Loading history...
13
from .periodicset import PeriodicSet
14
from .utils import diameter
15
16
# Accepted input type for functions taking a periodic set: either a
# PeriodicSet object or a (motif, cell) tuple of ndarrays.
PSET_OR_TUPLE = Union[PeriodicSet, Tuple[np.ndarray, np.ndarray]]
17
18
19
def AMD(periodic_set: PSET_OR_TUPLE, k: int) -> np.ndarray:
    """Return the AMD (average minimum distance) of a periodic set up to `k`.

    Parameters
    ----------
    periodic_set : :class:`.periodicset.PeriodicSet` or tuple of ndarrays
        A periodic set represented by a :class:`.periodicset.PeriodicSet` or
        by a tuple (motif, cell) with coordinates in Cartesian form.
    k : int
        Length of the AMD returned.

    Returns
    -------
    ndarray
        An ndarray of shape (k,), the AMD of ``periodic_set`` up to `k`.

    Examples
    --------
    Make list of AMDs with ``k=100`` for crystals in mycif.cif::

        amds = []
        for periodic_set in amd.CifReader('mycif.cif'):
            amds.append(amd.AMD(periodic_set, 100))

    Make list of AMDs with ``k=10`` for crystals in these CSD refcode families::

        amds = []
        for periodic_set in amd.CSDReader(['HXACAN', 'ACSALA'], families=True):
            amds.append(amd.AMD(periodic_set, 10))

    Manually pass a periodic set as a tuple (motif, cell)::

        # simple cubic lattice
        motif = np.array([[0,0,0]])
        cell = np.array([[1,0,0], [0,1,0], [0,0,1]])
        cubic_amd = amd.AMD((motif, cell), 100)
    """
    motif, cell, asym_unit, wyckoff_mults = _extract_motif_and_cell(periodic_set)
    # Only the first output of nearest_neighbours (the distance matrix) is used.
    dists, _, _ = nearest_neighbours(motif, cell, k, asymmetric_unit=asym_unit)
    # When multiplicities are unavailable (None), np.average is a plain mean.
    return np.average(dists, axis=0, weights=wyckoff_mults)
60
61
def PDD(
        periodic_set: PSET_OR_TUPLE,
        k: int,
        lexsort: bool = True,
        collapse: bool = True,
        collapse_tol: float = 1e-4
) -> np.ndarray:
    """Return the PDD (pointwise distance distribution) of a periodic set up to `k`.

    Parameters
    ----------
    periodic_set : :class:`.periodicset.PeriodicSet` or tuple of ndarrays
        A periodic set represented by a :class:`.periodicset.PeriodicSet` or
        by a tuple (motif, cell) with coordinates in Cartesian form.
    k : int
        Number of columns in the PDD (the returned matrix has an additional first
        column containing weights).
    lexsort : bool, optional
        Whether or not to lexicographically order the rows. Default True.
    collapse: bool, optional
        Whether or not to collapse identical rows (within a tolerance). Default True.
    collapse_tol: float
        If two rows have all entries closer than collapse_tol, they get collapsed.
        Default is 1e-4.

    Returns
    -------
    ndarray
        An ndarray with k+1 columns, the PDD of ``periodic_set`` up to `k`.
        The first column holds the row weights, which sum to 1.

    Examples
    --------
    Make list of PDDs with ``k=100`` for crystals in mycif.cif::

        pdds = []
        for periodic_set in amd.CifReader('mycif.cif'):
            # do not lexicographically order rows
            pdds.append(amd.PDD(periodic_set, 100, lexsort=False))

    Make list of PDDs with ``k=10`` for crystals in these CSD refcode families::

        pdds = []
        for periodic_set in amd.CSDReader(['HXACAN', 'ACSALA'], families=True):
            # do not collapse rows
            pdds.append(amd.PDD(periodic_set, 10, collapse=False))

    Manually pass a periodic set as a tuple (motif, cell)::

        # simple cubic lattice
        motif = np.array([[0,0,0]])
        cell = np.array([[1,0,0], [0,1,0], [0,0,1]])
        cubic_pdd = amd.PDD((motif, cell), 100)
    """
    motif, cell, asym_unit, wyckoff_mults = _extract_motif_and_cell(periodic_set)
    dists, _, _ = nearest_neighbours(motif, cell, k, asymmetric_unit=asym_unit)

    if wyckoff_mults is not None:
        weights = wyckoff_mults / np.sum(wyckoff_mults)
    else:
        n_points = motif.shape[0]
        weights = np.full((n_points, ), 1 / n_points)

    if collapse:
        weights, dists = _collapse_rows(weights, dists, collapse_tol)

    pdd = np.hstack((weights[:, None], dists))
    if not lexsort:
        return pdd

    # np.lexsort uses its *last* key as the primary key; rotating the matrix
    # makes the first distance column primary. Weights are not sort keys.
    row_order = np.lexsort(np.rot90(dists))
    return pdd[row_order]
132
133
134
def PDD_to_AMD(pdd: np.ndarray) -> np.ndarray:
    """Compute the AMD from an existing PDD, avoiding a second distance search.

    Parameters
    ----------
    pdd : np.ndarray
        The PDD of a periodic set; column 0 holds row weights and the remaining
        columns hold distances.

    Returns
    -------
    ndarray
        The AMD of the periodic set: the weighted column-wise mean of the
        distance columns.
    """
    row_weights = pdd[:, 0]
    distances = pdd[:, 1:]
    return np.average(distances, axis=0, weights=row_weights)
149
150
151
def AMD_finite(motif: np.ndarray) -> np.ndarray:
    """Return the AMD of a finite point set (up to k = `len(motif) - 1`).

    Parameters
    ----------
    motif : ndarray
        Cartesian coordinates of points in a set. Shape (n_points, dimensions)

    Returns
    -------
    ndarray
        A vector of length len(motif) - 1, the AMD of ``motif``.

    Examples
    --------
    Find AMD distance between finite trapezium and kite point sets::

        trapezium = np.array([[0,0],[1,1],[3,1],[4,0]])
        kite      = np.array([[0,0],[1,1],[1,-1],[4,0]])

        trap_amd = amd.AMD_finite(trapezium)
        kite_amd = amd.AMD_finite(kite)

        dist = amd.AMD_pdist(trap_amd, kite_amd)
    """
    pairwise = scipy.spatial.distance.squareform(scipy.spatial.distance.pdist(motif))
    # After sorting, each row starts with a zero (the point's distance to
    # itself); drop that column.
    neighbour_dists = np.sort(pairwise, axis=-1)[:, 1:]
    return np.average(neighbour_dists, axis=0)
179
180
181
def PDD_finite(
        motif: np.ndarray,
        lexsort: bool = True,
        collapse: bool = True,
        collapse_tol: float = 1e-4
) -> np.ndarray:
    """Return the PDD of a finite point set (up to k = `len(motif) - 1`).

    Parameters
    ----------
    motif : ndarray
        Cartesian coordinates of points in a set. Shape (n_points, dimensions)
    lexsort : bool, optional
        Whether or not to lexicographically order the rows. Default True.
    collapse: bool, optional
        Whether or not to collapse identical rows (within a tolerance). Default True.
    collapse_tol: float
        If two rows have all entries closer than collapse_tol, they get collapsed.
        Default is 1e-4.

    Returns
    -------
    ndarray
        An ndarray with len(motif) columns, the PDD of ``motif``. Column 0 holds
        the row weights (initially 1 / len(motif) each).

    Examples
    --------
    Find PDD distance between finite trapezium and kite point sets::

        trapezium = np.array([[0,0],[1,1],[3,1],[4,0]])
        kite      = np.array([[0,0],[1,1],[1,-1],[4,0]])

        trap_pdd = amd.PDD_finite(trapezium)
        kite_pdd = amd.PDD_finite(kite)

        dist = amd.emd(trap_pdd, kite_pdd)
    """
    n_points = motif.shape[0]
    pairwise = scipy.spatial.distance.squareform(scipy.spatial.distance.pdist(motif))
    # Sorted rows start with each point's zero self-distance; drop it.
    dists = np.sort(pairwise, axis=-1)[:, 1:]
    weights = np.full((n_points, ), 1 / n_points)

    if collapse:
        weights, dists = _collapse_rows(weights, dists, collapse_tol)

    pdd = np.hstack((weights[:, None], dists))
    if lexsort:
        # np.lexsort's last key is primary, so rotate to sort by the first column.
        pdd = pdd[np.lexsort(np.rot90(dists))]
    return pdd
233
234
235
def SDD(
        motif: np.ndarray,
        order: int = 1,
        lexsort: bool = True,
        collapse: bool = True,
        collapse_tol: float = 1e-4):
    """The SDD (simplex-wise distance distribution) of a finite point set.

    The SDD with order h considers every h-sized collection of points in the
    motif. The first-order SDD has `len(motif) - 1` columns and is equivalent
    to the PDD for finite sets.

    Parameters
    ----------
    motif : ndarray
        Cartesian coordinates of points in a set. Shape (n_points, dimensions)
    order : int
        Order of the SDD, default 1. Must be at least 1; for orders above 1 it
        must also be smaller than the number of points. See papers for a
        description of higher-order SDDs.
    lexsort : bool, optional
        Whether or not to lexicographically order the rows. Default True.
    collapse: bool, optional
        Whether or not to collapse identical rows (within a tolerance). Default True.
    collapse_tol: float
        If two rows have all entries closer than collapse_tol, they get collapsed.
        Default is 1e-4.

    Returns
    -------
    tuple of ndarrays
        The h-order SDD of ``motif``. A tuple of 3 arrays is returned,
        ``weights``, ``dist`` and ``sdd``. If order=1, dist is None.
        For order h > 1 there is one row per h-subset of points (before
        collapsing). ``dist`` holds the distance between the two points when
        h = 2, or the row-sorted finite PDD (without weights) of the h points
        when h > 2. Each entry of ``sdd`` is an (n_points - h, h) matrix of
        distances from every other point to the h points, with each row sorted
        and rows lexicographically ordered.

    Raises
    ------
    ValueError
        If ``order`` is less than 1, or if ``order`` > 1 and ``motif`` does
        not have more than ``order`` points.

    Examples
    --------
    Find the SDD of the trapezium and kite point sets::

        trapezium = np.array([[0,0],[1,1],[3,1],[4,0]])
        kite      = np.array([[0,0],[1,1],[1,-1],[4,0]])

        trap_sdd = amd.SDD(trapezium, order=2)
        kite_sdd = amd.SDD(kite)
    """
    # Without this check, order=0 fails deep inside np.lexsort and negative
    # orders fail inside itertools.combinations with unrelated messages.
    if order < 1:
        raise ValueError(f'The SDD order must be at least 1, got {order}')

    m = motif.shape[0]

    if order == 1:
        dm = scipy.spatial.distance.squareform(scipy.spatial.distance.pdist(motif))
        # Sorted rows start with each point's zero self-distance; drop it.
        sdd = np.sort(dm, axis=-1)[:, 1:]
        weights = np.full((m, ), 1 / m)

        if collapse:
            weights, sdd = _collapse_rows(weights, sdd, collapse_tol)

        if lexsort:
            sorted_inds = np.lexsort(np.rot90(sdd))
            weights, sdd = weights[sorted_inds], sdd[sorted_inds]

        return weights, None, sdd

    if m <= order:
        raise ValueError(
            f'The higher order SDD is only defined when the order ({order}) is '
            f'smaller than the number of points ({m})'
        )

    dm = scipy.spatial.distance.squareform(scipy.spatial.distance.pdist(motif))
    dist = []
    sdd = []

    for points in itertools.combinations(range(m), order):
        points = list(points)
        # Select the rows of dm for every point NOT in this h-subset.
        remove_rows = np.full((m, ), True)
        np.put(remove_rows, points, False)
        unsorted_row = np.sort(dm[remove_rows][:, points], axis=-1)
        sorted_row = unsorted_row[np.lexsort(np.rot90(unsorted_row))]
        sdd.append(sorted_row)

        if order == 2:
            dist.append(dm[points[0], points[1]])
        else:
            dists = dm[points][:, points]
            dists = np.sort(dists, axis=-1)[:, 1:]
            pdd_finite = dists[np.lexsort(np.rot90(dists))]
            dist.append(pdd_finite)

    sdd, dist = np.array(sdd), np.array(dist)
    n_rows = scipy.special.comb(m, order, exact=True)
    weights = np.full((n_rows, ), 1 / n_rows)

    if collapse:
        dist_diffs = np.abs(dist[:, None] - dist) <= collapse_tol

        if dist.ndim == 1:
            dist_overlapping = dist_diffs
        else:
            dist_overlapping = np.all(np.all(dist_diffs, axis=-1), axis=-1)

        sdd_diffs = np.abs(sdd[:, None] - sdd) <= collapse_tol
        sdd_overlapping = np.all(np.all(sdd_diffs, axis=-1), axis=-1)
        # Rows merge only if both their dist and sdd parts are within tolerance.
        overlapping = np.logical_and(sdd_overlapping, dist_overlapping)
        res = _group_overlapping_and_sum_weights(weights, overlapping)
        if res is not None:
            weights, dist, sdd = res[0], dist[res[1]], sdd[res[1]]

    if lexsort:
        flat_sdd = sdd.reshape((sdd.shape[0], sdd.shape[1] * sdd.shape[2]))
        if order == 2:
            # Sort primarily by the pair distance, then by the flattened SDD.
            keys = np.hstack((dist[:, None], flat_sdd))
        else:
            flat_dist = dist.reshape((dist.shape[0], dist.shape[1] * dist.shape[2]))
            keys = np.hstack((flat_dist, flat_sdd))
        args = np.lexsort(np.rot90(keys))
        weights, dist, sdd = weights[args], dist[args], sdd[args]

    return weights, dist, sdd
344
345
346
def PDD_reconstructable(
        periodic_set: PSET_OR_TUPLE,
        lexsort: bool = True
) -> np.ndarray:
    """The PDD of a periodic set with `k` (number of columns) large enough
    such that the periodic set can be reconstructed from the PDD.

    Only 2 and 3 dimensional periodic sets are supported. Distances are
    computed by ``nearest_neighbours_minval`` with a minimum value of twice
    the diameter of the unit cell, so `k` is not chosen by the caller. This
    function adds no weights column and never collapses rows; the returned
    matrix is exactly what ``nearest_neighbours_minval`` produces, optionally
    with its rows reordered.

    Parameters
    ----------
    periodic_set : :class:`.periodicset.PeriodicSet` or tuple of ndarrays
        A periodic set represented by a :class:`.periodicset.PeriodicSet` or
        by a tuple (motif, cell) with coordinates in Cartesian form.
    lexsort : bool, optional
        Whether or not to lexicographically order the rows (using every column
        of the matrix, first column as primary key). Default True.

    Returns
    -------
    ndarray
        The distance matrix returned by ``nearest_neighbours_minval``.

    Raises
    ------
    ValueError
        If the periodic set is not 2 or 3 dimensional.

    Examples
    --------
    Make list of reconstructable PDDs for crystals in mycif.cif::

        pdds = []
        for periodic_set in amd.CifReader('mycif.cif'):
            pdds.append(amd.PDD_reconstructable(periodic_set))

    Manually pass a periodic set as a tuple (motif, cell)::

        # simple cubic lattice
        motif = np.array([[0,0,0]])
        cell = np.array([[1,0,0], [0,1,0], [0,0,1]])
        cubic_pdd = amd.PDD_reconstructable((motif, cell))
    """
    motif, cell, _, _ = _extract_motif_and_cell(periodic_set)
    dims = cell.shape[0]

    if dims not in (2, 3):
        raise ValueError('Reconstructing from PDD only implemented for 2 and 3 dimensions')

    # NOTE(review): twice the cell diameter is presumably the distance bound
    # needed for reconstructability — confirm against the papers.
    min_val = diameter(cell) * 2
    pdd = nearest_neighbours_minval(motif, cell, min_val)

    if lexsort:
        pdd = pdd[np.lexsort(np.rot90(pdd))]

    return pdd
412
413
414
def PDF(periodic_set, cutoff_r: float) -> np.ndarray:
    """The PDF (pair distribution function) of a periodic set up to a cutoff
    radius r. This is a 1D vector of sorted distances between all points
    pairwise (where at least one of two points is in the motif).

    Parameters
    ----------
    periodic_set : :class:`.periodicset.PeriodicSet` or tuple of ndarrays
        A periodic set represented by a :class:`.periodicset.PeriodicSet` or
        by a tuple (motif, cell) with coordinates in Cartesian form.
    cutoff_r : float
        Cutoff radius for distances to find.

    Returns
    -------
    ndarray
        A 1D ndarray of sorted, non-zero distances no greater than
        ``cutoff_r``, the PDF of ``periodic_set``.
    """
    # _extract_motif_and_cell handles both PeriodicSet objects and (motif, cell)
    # tuples; unpacking ``periodic_set`` directly would break PeriodicSet input.
    motif, cell, _, _ = _extract_motif_and_cell(periodic_set)
    generator = generate_concentric_cloud(motif, cell)

    cloud = []
    while True:
        # Consume layers two at a time; stop once every point in the newest
        # pair of layers is further than cutoff_r from every motif point.
        next_layer = np.vstack((next(generator), next(generator)))
        cloud.append(next_layer)
        if np.all(scipy.spatial.distance.cdist(motif, next_layer) > cutoff_r):
            break
    # NOTE(review): one extra layer is added after the stop condition,
    # presumably as a safety margin — confirm.
    cloud.append(next(generator))
    cloud = np.concatenate(cloud)

    pdf = np.sort(scipy.spatial.distance.cdist(motif, cloud).flatten())
    # Drop zero distances (presumably each motif point to its own copy in the cloud).
    pdf = pdf[(pdf <= cutoff_r) & (pdf != 0)]
    return pdf
449
450
451
def PPC(periodic_set: PSET_OR_TUPLE) -> float:
    r"""The point packing coefficient (PPC) of ``periodic_set``.

    The PPC is a constant of any periodic set determining the
    asymptotic behaviour of its AMD or PDD as :math:`k \rightarrow \infty`.

    As :math:`k \rightarrow \infty`, the ratio :math:`\text{AMD}_k / \sqrt[n]{k}`
    approaches the PPC (as does any row of its PDD).

    For a unit cell :math:`U` and :math:`m` motif points in :math:`n` dimensions,

    .. math::

        \text{PPC} = \sqrt[n]{\frac{\text{Vol}[U]}{m V_n}}

    where :math:`V_n` is the volume of a unit sphere in :math:`n` dimensions.

    Parameters
    ----------
    periodic_set : :class:`.periodicset.PeriodicSet` or tuple of
        ndarrays (motif, cell) representing the periodic set in Cartesian form.

    Returns
    -------
    float
        The PPC of ``periodic_set``.
    """
    # np.math was removed in NumPy 2.0 and math.factorial rejects floats
    # since Python 3.10, so use the stdlib directly with integer arguments.
    import math

    motif, cell, _, _ = _extract_motif_and_cell(periodic_set)
    m, n = motif.shape
    # The cell volume is |det|; a left-handed cell has a negative determinant,
    # which would otherwise make the n-th root below NaN.
    cell_volume = abs(np.linalg.det(cell))

    # Volume of the unit n-ball, using separate even/odd closed forms.
    t = n // 2
    if n % 2 == 0:
        unit_ball_volume = (np.pi ** t) / math.factorial(t)
    else:
        unit_ball_volume = (2 * math.factorial(t) * (4 * np.pi) ** t) / math.factorial(n)

    return (cell_volume / (m * unit_ball_volume)) ** (1. / n)
489
490
491
def AMD_estimate(periodic_set: PSET_OR_TUPLE, k: int) -> np.ndarray:
    r"""Estimate the AMD from the PPC, using the fact that

    .. math::

        \lim_{k\rightarrow\infty}\frac{\text{AMD}_k}{\sqrt[n]{k}}
        = \sqrt[n]{\frac{\text{Vol}[U]}{m V_n}}

    where :math:`U` is the unit cell, :math:`m` is the number of motif points and
    :math:`V_n` is the volume of a unit sphere in :math:`n`-dimensional space.

    Parameters
    ----------
    periodic_set : :class:`.periodicset.PeriodicSet` or tuple of ndarrays
        The periodic set, or a tuple (motif, cell) in Cartesian form.
    k : int
        Length of the estimated AMD.

    Returns
    -------
    ndarray
        Shape (k,); entry i (1-indexed) is ``PPC * i ** (1 / n)``.
    """
    motif, cell, _, _ = _extract_motif_and_cell(periodic_set)
    dims = motif.shape[1]
    ppc = PPC((motif, cell))
    exponent = 1. / dims
    estimates = [(x ** exponent) * ppc for x in range(1, k + 1)]
    return np.array(estimates)
506
507
508
def _extract_motif_and_cell(periodic_set: PSET_OR_TUPLE):
    """Split ``periodic_set`` into (motif, cell, asymmetric_unit, multiplicities).

    ``periodic_set`` may be a :class:`.periodicset.PeriodicSet`, a bare ndarray
    (a finite motif, for which ``cell`` is returned as None), or an indexable
    (motif, cell) pair. The asymmetric unit and Wyckoff multiplicities are only
    extracted from a PeriodicSet whose ``tags`` contain both
    ``'asymmetric_unit'`` and ``'wyckoff_multiplicities'``; otherwise both are
    returned as None.
    """
    if isinstance(periodic_set, PeriodicSet):
        motif = periodic_set.motif
        cell = periodic_set.cell
        tags = periodic_set.tags
        if 'asymmetric_unit' in tags and 'wyckoff_multiplicities' in tags:
            asym_unit = periodic_set.asymmetric_unit
            mults = periodic_set.wyckoff_multiplicities
            return motif, cell, asym_unit, mults
        return motif, cell, None, None

    if isinstance(periodic_set, np.ndarray):
        # A bare array is treated as a finite point set with no unit cell.
        return periodic_set, None, None, None

    return periodic_set[0], periodic_set[1], None, None
529
530
531
def _collapse_rows(weights, dists, collapse_tol):
    """Merge rows of ``dists`` that agree within ``collapse_tol``.

    Two rows overlap when every entry differs by at most ``collapse_tol``.
    Overlapping rows are grouped by ``_group_overlapping_and_sum_weights``:
    one representative row per group is kept, and the matching entries of
    ``weights`` are summed. When that helper reports nothing to merge, the
    inputs are returned unchanged.
    """
    # (n, n) boolean matrix: rows i and j overlap if all |dists[i] - dists[j]| <= tol.
    overlapping = np.all(np.abs(dists[:, None] - dists) <= collapse_tol, axis=-1)

    grouped = _group_overlapping_and_sum_weights(weights, overlapping)
    if grouped is None:
        return weights, dists

    new_weights, keep_inds = grouped
    return new_weights, dists[keep_inds]
546
547
548
def _group_overlapping_and_sum_weights(weights, overlapping):
    """Group overlapping rows and sum their weights.

    Parameters
    ----------
    weights : ndarray
        Vector of row weights, shape (n, ).
    overlapping : ndarray
        Boolean matrix of shape (n, n); ``overlapping[i, j]`` is True when rows
        i and j should be merged. NOTE(review): assumed symmetric, as produced
        by the callers in this module — confirm for any new caller.

    Returns
    -------
    tuple (ndarray, list) or None
        ``None`` if no two distinct rows overlap. Otherwise
        ``(new_weights, keep_inds)``: ``keep_inds`` holds the lowest row index
        of each group and ``new_weights`` the summed weight of each group, in
        the same order.
    """
    # Only the strict upper triangle is inspected; comparisons of a row with
    # itself (the diagonal) never trigger a merge.
    if not np.triu(overlapping, 1).any():
        return None

    # Greedy assignment: each row not yet seen opens a new group and pulls
    # every row it overlaps into that group. Overlap need not be transitive,
    # so a later row may reassign an earlier one.
    groups = {}
    next_group = 0
    for i, row in enumerate(overlapping):
        if i not in groups:
            groups[i] = next_group
            next_group += 1
        for j in np.flatnonzero(row):
            groups[j] = groups[i]

    groupings = collections.defaultdict(list)
    for ind, group in sorted(groups.items()):
        groupings[group].append(ind)

    keep_inds = [inds[0] for inds in groupings.values()]
    new_weights = np.array([np.sum(weights[inds]) for inds in groupings.values()])
    return new_weights, keep_inds
573