Passed
Push — master ( 4570af...abe023 )
by Daniel
03:02
created

amd.calculate.PDD_finite()   B

Complexity

Conditions 5

Size

Total Lines 72
Code Lines 27

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 27
dl 0
loc 72
rs 8.7653
c 0
b 0
f 0
cc 5
nop 5

How to fix   Long Method   

Long Method

Small methods make your code easier to understand, in particular if combined with a good name. Besides, if your method is small, finding a good name is usually much easier.

For example, if you find yourself adding comments to a method's body, this is usually a good sign to extract the commented part to a new method, and use the comment as a starting point when coming up with a good name for this new method.

Commonly applied refactorings include:

1
"""Functions for calculating the average minimum distance (AMD) and
point-wise distance distribution (PDD) isometric invariants of
periodic crystals and finite sets.
"""
5
6
from typing import Union, Tuple
7
import collections
8
9
import numpy as np
10
from scipy.spatial.distance import pdist, squareform
11
12
from . import utils
13
from ._nearest_neighbours import nearest_neighbours, nearest_neighbours_minval
14
from .periodicset import PeriodicSet
15
16
PeriodicSet_or_Tuple = Union[PeriodicSet, Tuple[np.ndarray, np.ndarray]]
17
18
19
def AMD(
        periodic_set: PeriodicSet_or_Tuple,
        k: int
) -> np.ndarray:
    """The AMD of a periodic set (crystal) up to k.

    Parameters
    ----------
    periodic_set : :class:`.periodicset.PeriodicSet` or tuple of :class:`numpy.ndarray` s
        A periodic set represented by a :class:`.periodicset.PeriodicSet` or
        by a tuple (motif, cell) with coordinates in Cartesian form and a square unit cell.
    k : int
        Length of the AMD returned; the number of neighbours considered for each atom
        in the unit cell to make the AMD.

    Returns
    -------
    numpy.ndarray
        A :class:`numpy.ndarray` shape (k, ), the AMD of ``periodic_set`` up to k.

    Examples
    --------
    Make list of AMDs with k = 100 for crystals in mycif.cif::

        amds = []
        for periodic_set in amd.CifReader('mycif.cif'):
            amds.append(amd.AMD(periodic_set, 100))

    Make list of AMDs with k = 10 for crystals in these CSD refcode families::

        amds = []
        for periodic_set in amd.CSDReader(['HXACAN', 'ACSALA'], families=True):
            amds.append(amd.AMD(periodic_set, 10))

    Manually pass a periodic set as a tuple (motif, cell)::

        # simple cubic lattice
        motif = np.array([[0,0,0]])
        cell = np.array([[1,0,0], [0,1,0], [0,0,1]])
        cubic_amd = amd.AMD((motif, cell), 100)
    """

    motif, cell, asymmetric_unit, multiplicities = _extract_motif_and_cell(periodic_set)
    # Only the first value returned by nearest_neighbours (the rows of
    # neighbour distances) is needed here.
    pdd, _, _ = nearest_neighbours(motif, cell, k, asymmetric_unit=asymmetric_unit)
    # Column-wise mean over the rows; when multiplicities is None,
    # np.average falls back to an unweighted mean.
    return np.average(pdd, axis=0, weights=multiplicities)
64
65
66
def PDD(
        periodic_set: PeriodicSet_or_Tuple,
        k: int,
        lexsort: bool = True,
        collapse: bool = True,
        collapse_tol: float = 1e-4,
        return_row_groups: bool = False,
) -> Union[np.ndarray, Tuple[np.ndarray, list]]:
    """The PDD of a periodic set (crystal) up to k.

    Parameters
    ----------
    periodic_set : :class:`.periodicset.PeriodicSet` or tuple of :class:`numpy.ndarray` s
        A periodic set represented by a :class:`.periodicset.PeriodicSet` or
        by a tuple (motif, cell) with coordinates in Cartesian form and a square unit cell.
    k : int
        The returned PDD has k+1 columns, an additional first column for row weights.
        k is the number of neighbours considered for each atom in the unit cell
        to make the PDD.
    lexsort : bool, optional
        Lexicographically order the rows. Default True.
    collapse: bool, optional
        Collapse repeated rows (within the tolerance ``collapse_tol``). Default True.
    collapse_tol: float, optional
        If two rows have all elements closer than ``collapse_tol``, they are merged and
        weights are given to rows in proportion to the number of times they appeared.
        Default is 0.0001.
    return_row_groups: bool, optional
        Return data about which PDD rows correspond to which points.
        If True, a tuple is returned ``(pdd, groups)`` where ``groups[i]``
        contains the indices of the point(s) corresponding to ``pdd[i]``.
        Note that these indices are for the asymmetric unit of the set, whose
        indices in ``periodic_set.motif`` are accessible through
        ``periodic_set.asymmetric_unit``.

    Returns
    -------
    numpy.ndarray
        A :class:`numpy.ndarray` with k+1 columns, the PDD of ``periodic_set`` up to k.
        The first column contains the weights of rows. If ``return_row_groups`` is True,
        returns a tuple (:class:`numpy.ndarray`, list).

    Examples
    --------
    Make list of PDDs with ``k=100`` for crystals in mycif.cif::

        pdds = []
        for periodic_set in amd.CifReader('mycif.cif'):
            # do not lexicographically order rows
            pdds.append(amd.PDD(periodic_set, 100, lexsort=False))

    Make list of PDDs with ``k=10`` for crystals in these CSD refcode families::

        pdds = []
        for periodic_set in amd.CSDReader(['HXACAN', 'ACSALA'], families=True):
            # do not collapse rows
            pdds.append(amd.PDD(periodic_set, 10, collapse=False))

    Manually pass a periodic set as a tuple (motif, cell)::

        # simple cubic lattice
        motif = np.array([[0,0,0]])
        cell = np.array([[1,0,0], [0,1,0], [0,0,1]])
        cubic_pdd = amd.PDD((motif, cell), 100)
    """

    motif, cell, asymmetric_unit, multiplicities = _extract_motif_and_cell(periodic_set)
    dists, _, _ = nearest_neighbours(motif, cell, k, asymmetric_unit=asymmetric_unit)
    # Before any collapsing, every row stands for exactly one point.
    groups = [[i] for i in range(len(dists))]

    if multiplicities is None:
        # No multiplicity data: every motif point gets the same weight.
        weights = np.full((motif.shape[0], ), 1 / motif.shape[0])
    else:
        weights = multiplicities / np.sum(multiplicities)

    if collapse:
        # Chebyshev distance below the tolerance means every element of the
        # two rows differs by less than collapse_tol.
        overlapping = pdist(dists, metric='chebyshev') < collapse_tol
        if overlapping.any():
            groups = _collapse_into_groups(overlapping)
            # A merged row carries the summed weight of its members and the
            # distances of the group's first member.
            weights = np.array([sum(weights[group]) for group in groups])
            dists = dists[[group[0] for group in groups]]

    pdd = np.hstack((weights[:, None], dists))

    if lexsort:
        # Order by the distance columns only (first column most significant);
        # the weight column does not take part in the ordering.
        lex_ordering = np.lexsort(np.rot90(dists))
        groups = [groups[i] for i in lex_ordering]
        pdd = pdd[lex_ordering]

    if return_row_groups:
        return pdd, groups
    return pdd
161
162
163
def PDD_to_AMD(pdd: np.ndarray) -> np.ndarray:
    """Calculates an AMD from a PDD. Faster than computing both from scratch.

    Parameters
    ----------
    pdd : numpy.ndarray
        The PDD of a periodic set. The first column is read as the row
        weights and the remaining columns as the distances.

    Returns
    -------
    numpy.ndarray
        The AMD of the periodic set: the weighted average of the distance
        columns of ``pdd``, one entry per distance column.
    """

    row_weights = pdd[:, 0]
    distances = pdd[:, 1:]
    return np.average(distances, weights=row_weights, axis=0)
178
179
180
def AMD_finite(motif: np.ndarray) -> np.ndarray:
    """The AMD of a finite m-point set up to k = m-1.

    Parameters
    ----------
    motif : numpy.ndarray
        Coordinates of a set of points.

    Returns
    -------
    numpy.ndarray
        A vector length m-1 (where m is the number of points), the AMD of ``motif``.

    Examples
    --------
    The AMD distance (L-infinity) between finite trapezium and kite point sets::

        trapezium = np.array([[0,0],[1,1],[3,1],[4,0]])
        kite      = np.array([[0,0],[1,1],[1,-1],[4,0]])

        trap_amd = amd.AMD_finite(trapezium)
        kite_amd = amd.AMD_finite(kite)

        l_inf_dist = np.amax(np.abs(trap_amd - kite_amd))
    """

    # Full symmetric matrix of pairwise distances, each row sorted ascending.
    sorted_rows = np.sort(squareform(pdist(motif)), axis=-1)
    # The first column of every sorted row is the point's zero distance to
    # itself, so it is dropped before averaging over the points.
    return np.average(sorted_rows[:, 1:], axis=0)
208
209
210
def PDD_finite(
        motif: np.ndarray,
        lexsort: bool = True,
        collapse: bool = True,
        collapse_tol: float = 1e-4,
        return_row_groups: bool = False,
) -> Union[np.ndarray, Tuple[np.ndarray, list]]:
    """The PDD of a finite m-point set up to k = m-1.

    Parameters
    ----------
    motif : numpy.ndarray
        Coordinates of a set of points.
    lexsort : bool, optional
        Whether or not to lexicographically order the rows. Default True.
    collapse: bool, optional
        Whether or not to collapse repeated rows (within the tolerance ``collapse_tol``).
        Default True.
    collapse_tol: float
        If two rows have all elements closer than ``collapse_tol``, they are merged and
        weights are given to rows in proportion to the number of times they appeared.
        Default is 0.0001.
    return_row_groups: bool, optional
        Whether to return data about which PDD rows correspond to which points.
        If True, a tuple is returned ``(pdd, groups)`` where ``groups[i]``
        contains the indices of the point(s) corresponding to ``pdd[i]``.

    Returns
    -------
    numpy.ndarray
        A :class:`numpy.ndarray` with m columns (where m is the number of points),
        the PDD of ``motif``. The first column contains the weights of rows.
        If ``return_row_groups`` is True, returns a tuple
        (:class:`numpy.ndarray`, list).

    Examples
    --------
    Find PDD distance between finite trapezium and kite point sets::

        trapezium = np.array([[0,0],[1,1],[3,1],[4,0]])
        kite      = np.array([[0,0],[1,1],[1,-1],[4,0]])

        trap_pdd = amd.PDD_finite(trapezium)
        kite_pdd = amd.PDD_finite(kite)

        dist = amd.EMD(trap_pdd, kite_pdd)
    """

    m = motif.shape[0]
    # Sorted pairwise distances per point; the leading column is each
    # point's zero distance to itself and is dropped.
    dists = np.sort(squareform(pdist(motif)), axis=-1)[:, 1:]
    weights = np.full((m, ), 1 / m)
    # Before any collapsing, every row stands for exactly one point.
    groups = [[i] for i in range(len(dists))]

    if collapse:
        # Chebyshev distance below the tolerance means every element of the
        # two rows differs by less than collapse_tol.
        overlapping = pdist(dists, metric='chebyshev') < collapse_tol
        if overlapping.any():
            groups = _collapse_into_groups(overlapping)
            # A merged row carries the summed weight of its members and the
            # distances of the group's first member.
            weights = np.array([sum(weights[group]) for group in groups])
            dists = dists[[group[0] for group in groups]]

    pdd = np.hstack((weights[:, None], dists))

    if lexsort:
        # Order by the distance columns only (first column most significant);
        # the weight column does not take part in the ordering.
        lex_ordering = np.lexsort(np.rot90(dists))
        groups = [groups[i] for i in lex_ordering]
        pdd = pdd[lex_ordering]

    if return_row_groups:
        return pdd, groups
    return pdd
282
283
284
def PDD_reconstructable(
        periodic_set: PeriodicSet_or_Tuple,
        lexsort: bool = True
) -> np.ndarray:
    """The PDD of a periodic set with `k` (no of columns) large enough such that
    the periodic set can be reconstructed from the PDD.

    Parameters
    ----------
    periodic_set : :class:`.periodicset.PeriodicSet` or tuple of :class:`numpy.ndarray` s
        A periodic set represented by a :class:`.periodicset.PeriodicSet` or
        by a tuple (motif, cell) with coordinates in Cartesian form and a square unit cell.
    lexsort : bool, optional
        Whether or not to lexicographically order the rows. Default True.

    Returns
    -------
    numpy.ndarray
        An ndarray, the PDD of ``periodic_set`` with enough columns to be reconstructable.

    Raises
    ------
    ValueError
        If the unit cell does not have 2 or 3 rows (dimensions).
    """

    motif, cell, _, _ = _extract_motif_and_cell(periodic_set)
    dims = cell.shape[0]

    if dims not in (2, 3):
        raise ValueError('Reconstructing from PDD only implemented for 2 and 3 dimensions')

    # Threshold handed to nearest_neighbours_minval: twice the value of
    # utils.diameter(cell).
    min_val = utils.diameter(cell) * 2
    pdd = nearest_neighbours_minval(motif, cell, min_val)

    if lexsort:
        # Lexicographic row order, first column most significant.
        pdd = pdd[np.lexsort(np.rot90(pdd))]

    return pdd
318
319
320
def PPC(periodic_set: PeriodicSet_or_Tuple) -> float:
    r"""The point packing coefficient (PPC) of ``periodic_set``.

    The PPC is a constant of any periodic set determining the
    asymptotic behaviour of its AMD and PDD as :math:`k \rightarrow \infty`.

    As :math:`k \rightarrow \infty`, the ratio :math:`\text{AMD}_k / \sqrt[n]{k}`
    approaches the PPC (as does any row of its PDD).

    For a unit cell :math:`U` and :math:`m` motif points in :math:`n` dimensions,

    .. math::

        \text{PPC} = \sqrt[n]{\frac{\text{Vol}[U]}{m V_n}}

    where :math:`V_n` is the volume of a unit sphere in :math:`n` dimensions.

    Parameters
    ----------
    periodic_set : :class:`.periodicset.PeriodicSet` or tuple of
        ndarrays (motif, cell) representing the periodic set in Cartesian form.

    Returns
    -------
    float
        The PPC of ``periodic_set``.
    """

    # Function-scope import: the standard-library factorial is used instead
    # of the ``np.math`` alias, which newer NumPy releases no longer provide.
    import math

    motif, cell, _, _ = _extract_motif_and_cell(periodic_set)
    m, n = motif.shape
    # The cell volume is the absolute determinant; the sign only encodes
    # the orientation of the basis vectors.
    cell_volume = abs(np.linalg.det(cell))

    # Integer floor(n / 2): math.factorial needs an int, not a float.
    half = n // 2
    if n % 2 == 0:
        # V_n = pi^(n/2) / (n/2)!
        unit_ball_volume = (np.pi ** half) / math.factorial(half)
    else:
        # V_n = 2 * ((n-1)/2)! * (4*pi)^((n-1)/2) / n!
        unit_ball_volume = (
            2 * math.factorial(half) * (4 * np.pi) ** half
        ) / math.factorial(n)

    return (cell_volume / (m * unit_ball_volume)) ** (1. / n)
358
359
360
def AMD_estimate(periodic_set: PeriodicSet_or_Tuple, k: int) -> np.ndarray:
    r"""Calculates an estimate of AMD based on the PPC, using the fact that

    .. math::

        \lim_{k\rightarrow\infty}\frac{\text{AMD}_k}{\sqrt[n]{k}} = \sqrt[n]{\frac{\text{Vol}[U]}{m V_n}}

    where :math:`U` is the unit cell, :math:`m` is the number of motif points and
    :math:`V_n` is the volume of a unit sphere in :math:`n`-dimensional space.

    Parameters
    ----------
    periodic_set : :class:`.periodicset.PeriodicSet` or tuple of
        ndarrays (motif, cell) representing the periodic set in Cartesian form.
    k : int
        Number of entries in the estimate.

    Returns
    -------
    numpy.ndarray
        A vector of length k whose i-th entry (counting from 1) is
        :math:`\sqrt[n]{i} \cdot \text{PPC}`.
    """

    motif, cell, _, _ = _extract_motif_and_cell(periodic_set)
    n = motif.shape[1]
    c = PPC((motif, cell))
    # One vectorised expression over 1..k instead of a Python-level loop.
    return np.arange(1, k + 1) ** (1. / n) * c
375
376
377
def _extract_motif_and_cell(periodic_set: PeriodicSet_or_Tuple):
    """Unpack ``periodic_set`` into its motif, cell and symmetry data.

    `periodic_set` is either a :class:`.periodicset.PeriodicSet`, or
    a tuple of ndarrays (motif, cell). A single ndarray is also accepted
    and is treated as a motif with no cell.

    Returns
    -------
    tuple
        ``(motif, cell, asymmetric_unit, multiplicities)``. The last two
        are taken from a :class:`.periodicset.PeriodicSet` whose ``tags``
        contain both ``'asymmetric_unit'`` and ``'wyckoff_multiplicities'``;
        otherwise both are None. ``cell`` is None when a single ndarray
        was passed.
    """

    asymmetric_unit, multiplicities = None, None

    if isinstance(periodic_set, PeriodicSet):
        motif, cell = periodic_set.motif, periodic_set.cell

        tags = periodic_set.tags
        # Symmetry data is only used when both pieces are present.
        if 'asymmetric_unit' in tags and 'wyckoff_multiplicities' in tags:
            asymmetric_unit = periodic_set.asymmetric_unit
            multiplicities = periodic_set.wyckoff_multiplicities

    elif isinstance(periodic_set, np.ndarray):
        # A bare array is a motif on its own; there is no cell.
        motif, cell = periodic_set, None
    else:
        # Anything else is indexed as a (motif, cell) pair.
        motif, cell = periodic_set[0], periodic_set[1]

    return motif, cell, asymmetric_unit, multiplicities
398
399
400
def _collapse_into_groups(overlapping):
    """Group items that overlap.

    The vector `overlapping` indicates for each pair of items in a set whether
    or not the items overlap, in the shape of a condensed distance matrix.

    Returns a list of groups (lists) of item indices. Items are visited in
    index order: an item not yet in a group starts a new one, and every
    still-unassigned item overlapping with it joins that group. Groups are
    ordered by their first member and each group's indices are ascending.
    """

    assigned = set()    # indices already placed in some group
    groups = []
    for i, row in enumerate(squareform(overlapping)):
        if i in assigned:
            continue
        # The diagonal of a squareform matrix is zero, so `i` never shows up
        # among its own overlaps; every earlier index is already assigned,
        # so the indices collected here are all greater than `i`.
        members = [i]
        members.extend(
            int(j) for j in np.flatnonzero(row) if int(j) not in assigned
        )
        assigned.update(members)
        groups.append(members)

    return groups
423