Passed
Push — master ( a4dae9...10a6a8 )
by Daniel
07:04
created

amd.calculate   A

Complexity

Total Complexity 31

Size/Duplication

Total Lines 415
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
wmc 31
eloc 126
dl 0
loc 415
rs 9.92
c 0
b 0
f 0

10 Functions

Rating   Name   Duplication   Size   Complexity  
A AMD() 0 45 1
A PDD_to_AMD() 0 15 1
A PDD_reconstructable() 0 34 3
B PDD_finite() 0 71 5
A _extract_motif_cell() 0 20 4
A PPC() 0 38 2
B _collapse_into_groups() 0 23 6
A AMD_estimate() 0 15 1
B PDD() 0 89 7
A AMD_finite() 0 28 1
1
"""Functions for calculating the average minimum distance (AMD) and
point-wise distance distribution (PDD) isometric invariants of
periodic crystals and finite sets.
"""
5
6
from typing import Union, Tuple
7
import collections
8
9
import numpy as np
10
from scipy.spatial.distance import pdist, squareform
11
12
from .utils import diameter
13
from ._nns import nearest_neighbours, nearest_neighbours_minval
14
from .periodicset import PeriodicSet
15
16
# Input type accepted by the periodic-set functions in this module: either a
# PeriodicSet, or a tuple of two NumPy arrays unpacked as (motif, cell).
PeriodicSet_or_Tuple = Union[PeriodicSet, Tuple[np.ndarray, np.ndarray]]
17
18
19
def AMD(
        periodic_set: PeriodicSet_or_Tuple,
        k: int
) -> np.ndarray:
    """The AMD of a periodic set (crystal) up to k.

    Parameters
    ----------
    periodic_set : :class:`.periodicset.PeriodicSet` or tuple of :class:`numpy.ndarray` s
        A periodic set represented by a :class:`.periodicset.PeriodicSet` or
        by a tuple (motif, cell) with coordinates in Cartesian form and a square unit cell.
    k : int
        Length of the AMD returned; the number of neighbours considered for each atom
        in the unit cell to make the AMD.

    Returns
    -------
    numpy.ndarray
        A :class:`numpy.ndarray` shape (k, ), the AMD of ``periodic_set`` up to k.

    Examples
    --------
    Make list of AMDs with k = 100 for crystals in data.cif::

        amds = []
        for periodic_set in amd.CifReader('data.cif'):
            amds.append(amd.AMD(periodic_set, 100))

    Make list of AMDs with k = 10 for crystals in these CSD refcode families::

        amds = []
        for periodic_set in amd.CSDReader(['HXACAN', 'ACSALA'], families=True):
            amds.append(amd.AMD(periodic_set, 10))

    Manually pass a periodic set as a tuple (motif, cell)::

        # simple cubic lattice
        motif = np.array([[0,0,0]])
        cell = np.array([[1,0,0], [0,1,0], [0,0,1]])
        cubic_amd = amd.AMD((motif, cell), 100)
    """

    motif, cell, asymmetric_unit, weights = _extract_motif_cell(periodic_set)
    # Only the first item returned by nearest_neighbours is needed here
    # (presumably one row of neighbour distances per asymmetric-unit point
    # -- confirm against ._nns).
    distance_rows, _, _ = nearest_neighbours(motif, cell, asymmetric_unit, k)
    # Weighted column-wise mean over the rows gives the AMD vector.
    return np.average(distance_rows, axis=0, weights=weights)
64
65
66
def PDD(
        periodic_set: PeriodicSet_or_Tuple,
        k: int,
        lexsort: bool = True,
        collapse: bool = True,
        collapse_tol: float = 1e-4,
        return_row_groups: bool = False,
) -> Union[np.ndarray, Tuple[np.ndarray, list]]:
    """The PDD of a periodic set (crystal) up to k.

    Parameters
    ----------
    periodic_set : :class:`.periodicset.PeriodicSet` or tuple of :class:`numpy.ndarray` s
        A periodic set represented by a :class:`.periodicset.PeriodicSet` or
        by a tuple (motif, cell) with coordinates in Cartesian form and a square unit cell.
    k : int
        The returned PDD has k+1 columns, an additional first column for row weights.
        k is the number of neighbours considered for each atom in the unit cell
        to make the PDD.
    lexsort : bool, default True
        Lexicographically order the rows. Default True.
    collapse: bool, default True
        Collapse repeated rows (within the tolerance ``collapse_tol``). Default True.
    collapse_tol: float, default 1e-4
        If two rows have all elements closer than ``collapse_tol``, they are merged and
        weights are given to rows in proportion to the number of times they appeared.
        Default is 0.0001. No collapsing is attempted unless this is positive.
    return_row_groups: bool, default False
        Return data about which PDD rows correspond to which points.
        If True, a tuple is returned ``(pdd, groups)`` where ``groups[i]``
        contains the indices of the point(s) corresponding to ``pdd[i]``.
        Note that these indices are for the asymmetric unit of the set, whose
        indices in ``periodic_set.motif`` are accessible through
        ``periodic_set.asymmetric_unit``.

    Returns
    -------
    numpy.ndarray
        A :class:`numpy.ndarray` with k+1 columns, the PDD of ``periodic_set`` up to k.
        The first column contains the weights of rows. If ``return_row_groups`` is True,
        returns a tuple (:class:`numpy.ndarray`, list).

    Examples
    --------
    Make list of PDDs with ``k=100`` for crystals in data.cif::

        pdds = []
        for periodic_set in amd.CifReader('data.cif'):
            # do not lexicographically order rows
            pdds.append(amd.PDD(periodic_set, 100, lexsort=False))

    Make list of PDDs with ``k=10`` for crystals in these CSD refcode families::

        pdds = []
        for periodic_set in amd.CSDReader(['HXACAN', 'ACSALA'], families=True):
            # do not collapse rows
            pdds.append(amd.PDD(periodic_set, 10, collapse=False))

    Manually pass a periodic set as a tuple (motif, cell)::

        # simple cubic lattice
        motif = np.array([[0,0,0]])
        cell = np.array([[1,0,0], [0,1,0], [0,0,1]])
        cubic_amd = amd.PDD((motif, cell), 100)
    """

    motif, cell, asymmetric_unit, weights = _extract_motif_cell(periodic_set)
    dists, _, _ = nearest_neighbours(motif, cell, asymmetric_unit, k)
    # Before any collapsing, every row is its own singleton group.
    groups = [[i] for i in range(len(dists))]

    if collapse and collapse_tol > 0:
        # Chebyshev distance < tol  <=>  every element of the two rows is
        # closer than tol.
        overlapping = pdist(dists, metric='chebyshev') < collapse_tol
        if overlapping.any():
            groups = _collapse_into_groups(overlapping)
            # A merged row carries the total weight of its members and the
            # element-wise mean of their distances.
            weights = np.array([sum(weights[group]) for group in groups])
            dists = np.array([np.average(dists[group], axis=0) for group in groups])

    pdd = np.hstack((weights[:, None], dists))

    if lexsort:
        # Order rows by the distance columns only (the weight column is not
        # a sort key); rot90 makes the first distance column the primary key.
        lex_ordering = np.lexsort(np.rot90(dists))
        if return_row_groups:
            groups = [groups[i] for i in lex_ordering]
        pdd = pdd[lex_ordering]

    if return_row_groups:
        return pdd, groups
    return pdd
155
156
157
def PDD_to_AMD(pdd: np.ndarray) -> np.ndarray:
    """Calculates an AMD from a PDD. Faster than computing both from scratch.

    Parameters
    ----------
    pdd : numpy.ndarray
        The PDD of a periodic set. The first column is read as the row
        weights and the remaining columns as distances.

    Returns
    -------
    numpy.ndarray
        The AMD of the periodic set: the weighted average of the distance
        columns of ``pdd``.
    """

    row_weights = pdd[:, 0]
    distances = pdd[:, 1:]
    return np.average(distances, weights=row_weights, axis=0)
172
173
174
def AMD_finite(motif: np.ndarray) -> np.ndarray:
    """The AMD of a finite m-point set up to k = m-1.

    Parameters
    ----------
    motif : numpy.ndarray
        Coordinates of a set of points.

    Returns
    -------
    numpy.ndarray
        A vector length m-1 (where m is the number of points), the AMD of ``motif``.

    Examples
    --------
    The AMD distance (L-infinity) between finite trapezium and kite point sets::

        trapezium = np.array([[0,0],[1,1],[3,1],[4,0]])
        kite      = np.array([[0,0],[1,1],[1,-1],[4,0]])

        trap_amd = amd.AMD_finite(trapezium)
        kite_amd = amd.AMD_finite(kite)

        l_inf_dist = np.amax(np.abs(trap_amd - kite_amd))
    """

    distance_matrix = squareform(pdist(motif))
    # After sorting each row, column 0 holds the zero distance from a point
    # to itself; drop it so only distances to the other m-1 points remain.
    neighbour_dists = np.sort(distance_matrix, axis=-1)[:, 1:]
    return np.average(neighbour_dists, axis=0)
202
203
204
def PDD_finite(
        motif: np.ndarray,
        lexsort: bool = True,
        collapse: bool = True,
        collapse_tol: float = 1e-4,
        return_row_groups: bool = False,
) -> Union[np.ndarray, Tuple[np.ndarray, list]]:
    """The PDD of a finite m-point set up to k = m-1.

    Parameters
    ----------
    motif : numpy.ndarray
        Coordinates of a set of points.
    lexsort : bool, default True
        Whether or not to lexicographically order the rows. Default True.
    collapse: bool, default True
        Whether or not to collapse repeated rows (within the tolerance ``collapse_tol``).
        Default True.
    collapse_tol: float, default 1e-4
        If two rows have all elements closer than ``collapse_tol``, they are merged and
        weights are given to rows in proportion to the number of times they appeared.
        Default is 0.0001. No collapsing is attempted unless this is positive.
    return_row_groups: bool, default False
        Whether to return data about which PDD rows correspond to which points.
        If True, a tuple is returned ``(pdd, groups)`` where ``groups[i]``
        contains the indices of the point(s) corresponding to ``pdd[i]``.

    Returns
    -------
    numpy.ndarray
        A :class:`numpy.ndarray` with m columns (where m is the number of points),
        the PDD of ``motif``. The first column contains the weights of rows.
        If ``return_row_groups`` is True, returns a tuple
        (:class:`numpy.ndarray`, list).

    Examples
    --------
    Find PDD distance between finite trapezium and kite point sets::

        trapezium = np.array([[0,0],[1,1],[3,1],[4,0]])
        kite      = np.array([[0,0],[1,1],[1,-1],[4,0]])

        trap_pdd = amd.PDD_finite(trapezium)
        kite_pdd = amd.PDD_finite(kite)

        dist = amd.EMD(trap_pdd, kite_pdd)
    """

    dm = squareform(pdist(motif))
    m = motif.shape[0]
    # Sorted rows start with the zero self-distance; drop that column.
    dists = np.sort(dm, axis=-1)[:, 1:]
    # Every point starts with equal weight and as its own singleton group.
    weights = np.full((m, ), 1 / m)
    groups = [[i] for i in range(len(dists))]

    # A non-positive tolerance can never match any pair, so skip the
    # pairwise comparison entirely (same guard as PDD).
    if collapse and collapse_tol > 0:
        # Chebyshev distance < tol  <=>  every element of the two rows is
        # closer than tol.
        overlapping = pdist(dists, metric='chebyshev') < collapse_tol
        if overlapping.any():
            groups = _collapse_into_groups(overlapping)
            # A merged row carries the total weight of its members and the
            # element-wise mean of their distances.
            weights = np.array([sum(weights[group]) for group in groups])
            dists = np.array([np.average(dists[group], axis=0) for group in groups])

    pdd = np.hstack((weights[:, None], dists))

    if lexsort:
        # Order rows by the distance columns only (the weight column is not
        # a sort key); rot90 makes the first distance column the primary key.
        lex_ordering = np.lexsort(np.rot90(dists))
        groups = [groups[i] for i in lex_ordering]
        pdd = pdd[lex_ordering]

    if return_row_groups:
        return pdd, groups
    return pdd
275
276
277
def PDD_reconstructable(
        periodic_set: PeriodicSet_or_Tuple,
        lexsort: bool = True
) -> np.ndarray:
    """The PDD of a periodic set with `k` (no of columns) large enough such that
    the periodic set can be reconstructed from the PDD.

    Parameters
    ----------
    periodic_set : :class:`.periodicset.PeriodicSet` or tuple of :class:`numpy.ndarray` s
        A periodic set represented by a :class:`.periodicset.PeriodicSet` or
        by a tuple (motif, cell) with coordinates in Cartesian form and a square unit cell.
    lexsort : bool, default True
        Whether or not to lexicographically order the rows. Default True.

    Returns
    -------
    numpy.ndarray
        An ndarray, the PDD of ``periodic_set`` with enough columns to be reconstructable.

    Raises
    ------
    ValueError
        If the unit cell does not have 2 or 3 rows (dimensions).
    """

    motif, cell, _, _ = _extract_motif_cell(periodic_set)
    dims = cell.shape[0]

    if dims not in (2, 3):
        raise ValueError('Reconstructing from PDD only implemented for 2 and 3 dimensions')

    # The threshold handed to nearest_neighbours_minval is twice the cell
    # diameter (presumably the distance every row must reach -- confirm
    # against ._nns).
    min_val = diameter(cell) * 2
    pdd = nearest_neighbours_minval(motif, cell, min_val)

    if lexsort:
        # All columns of the returned array take part in the ordering here.
        pdd = pdd[np.lexsort(np.rot90(pdd))]

    return pdd
311
312
313
def PPC(periodic_set: PeriodicSet_or_Tuple) -> float:
    r"""The point packing coefficient (PPC) of ``periodic_set``.

    The PPC is a constant of any periodic set determining the
    asymptotic behaviour of its AMD and PDD. As :math:`k \rightarrow \infty`,
    the ratio :math:`\text{AMD}_k / \sqrt[n]{k}` converges to the PPC,
    as does any row of its PDD.

    For a unit cell :math:`U` and :math:`m` motif points in :math:`n` dimensions,

    .. math::

        \text{PPC} = \sqrt[n]{\frac{\text{Vol}[U]}{m V_n}}

    where :math:`V_n` is the volume of a unit sphere in :math:`n` dimensions.

    Parameters
    ----------
    periodic_set : :class:`.periodicset.PeriodicSet` or tuple of
        :class:`numpy.ndarray` s (motif, cell) representing the periodic set
        in Cartesian form.

    Returns
    -------
    float
        The PPC of ``periodic_set``.
    """

    # Function-scope import: the standard-library factorial replaces
    # ``np.math.factorial``; the ``np.math`` alias no longer exists in
    # NumPy 2.0 and later.
    import math

    motif, cell, _, _ = _extract_motif_cell(periodic_set)
    m, n = motif.shape
    # The determinant is signed (negative for a left-handed cell); the cell
    # volume is its magnitude. Without abs() the n-th root below is NaN.
    volume = abs(np.linalg.det(cell))
    # Integer floor division: math.factorial rejects float arguments.
    t = n // 2
    if n % 2 == 0:
        # Even n = 2t:  V_n = pi^t / t!
        V = (np.pi ** t) / math.factorial(t)
    else:
        # Odd n = 2t + 1:  V_n = 2 * t! * (4 pi)^t / n!
        V = (2 * math.factorial(t) * (4 * np.pi) ** t) / math.factorial(n)

    return (volume / (m * V)) ** (1. / n)
351
352
353
def AMD_estimate(periodic_set: PeriodicSet_or_Tuple, k: int) -> np.ndarray:
    r"""Calculates an estimate of AMD based on the PPC, using the fact that

    .. math::

        \lim_{k\rightarrow\infty}\frac{\text{AMD}_k}{\sqrt[n]{k}} = \sqrt[n]{\frac{\text{Vol}[U]}{m V_n}}

    where :math:`U` is the unit cell, :math:`m` is the number of motif points and
    :math:`V_n` is the volume of a unit sphere in :math:`n`-dimensional space.

    Parameters
    ----------
    periodic_set : :class:`.periodicset.PeriodicSet` or tuple of
        :class:`numpy.ndarray` s (motif, cell) representing the periodic set
        in Cartesian form.
    k : int
        Length of the estimate returned.

    Returns
    -------
    numpy.ndarray
        A vector of length k whose i-th entry (i = 1, ..., k) is
        :math:`\text{PPC} \cdot \sqrt[n]{i}`.
    """

    motif, cell, _, _ = _extract_motif_cell(periodic_set)
    n = motif.shape[1]
    c = PPC((motif, cell))
    # Vectorised form of  [i ** (1 / n) * c  for i in 1..k].
    neighbour_counts = np.arange(1, k + 1, dtype=float)
    return (neighbour_counts ** (1. / n)) * c
368
369
370
def _extract_motif_cell(pset: PeriodicSet_or_Tuple):
    """Unpack a periodic set into ``(motif, cell, asymmetric_unit, weights)``.

    `pset` is either a :class:`.periodicset.PeriodicSet`, or
    a tuple of ndarrays (motif, cell). If possible, extracts the asymmetric unit
    and wyckoff multiplicities.

    When `pset` is a :class:`.periodicset.PeriodicSet` with both
    ``asymmetric_unit`` and ``wyckoff_multiplicities`` set, the returned
    asymmetric unit is ``motif[pset.asymmetric_unit]`` and the weights are the
    multiplicities divided by their sum. In every other case the whole motif
    is returned as the asymmetric unit, with equal weights ``1 / len(motif)``.
    """

    if isinstance(pset, PeriodicSet):
        motif, cell = pset.motif, pset.cell
        has_symmetry_data = (pset.asymmetric_unit is not None
                             and pset.wyckoff_multiplicities is not None)
        if has_symmetry_data:
            asymmetric_unit = motif[pset.asymmetric_unit]
            # Normalise the multiplicities so the weights sum to 1.
            weights = pset.wyckoff_multiplicities / np.sum(pset.wyckoff_multiplicities)
            return motif, cell, asymmetric_unit, weights
    else:
        motif, cell = pset

    # Fallback shared by plain tuples and PeriodicSets lacking symmetry
    # data: every motif point counts equally.
    uniform_weights = np.full((len(motif), ), 1 / len(motif))
    return motif, cell, motif, uniform_weights
390
391
392
def _collapse_into_groups(overlapping):
    """Partition item indices into groups of overlapping items.

    The vector `overlapping` indicates for each pair of items in a set whether
    or not the items overlap, in the shape of a condensed distance matrix.

    Groups are built greedily in index order: the lowest index not yet in a
    group starts a new group, and every still-ungrouped item that overlaps
    that first item joins it. Members are therefore guaranteed to overlap the
    first item of their group, but not necessarily each other.

    Returns a list of groups (lists of indices). Every index appears in
    exactly one group; indices are ascending within each group and groups
    are ordered by their first index.
    """

    adjacency = squareform(overlapping)
    group_of = {}  # item index -> group number
    next_group = 0
    for founder, neighbours in enumerate(adjacency):
        if founder in group_of:
            continue
        group_of[founder] = next_group
        # Attach every still-unassigned item flagged in the founder's row.
        for member in np.flatnonzero(neighbours):
            group_of.setdefault(member, next_group)
        next_group += 1

    groups = collections.defaultdict(list)
    for index in sorted(group_of):
        groups[group_of[index]].append(index)

    return list(groups.values())
415