1 | """Calculation of isometry invariants from periodic sets.""" |
||
2 | |||
3 | import warnings |
||
4 | import collections |
||
5 | from typing import Tuple, Union |
||
6 | import itertools |
||
7 | |||
8 | import numpy as np |
||
9 | import numpy.typing as npt |
||
10 | import numba |
||
11 | from scipy.spatial.distance import pdist, squareform |
||
12 | |||
13 | from ._types import FloatArray |
||
14 | from ._nearest_neighbors import nearest_neighbors, nearest_neighbors_minval |
||
15 | from .periodicset import PeriodicSet |
||
16 | from .utils import diameter |
||
17 | from .globals_ import MAX_DISORDER_CONFIGS |
||
18 | |||
19 | |||
20 | __all__ = [ |
||
21 | "PDD", |
||
22 | "AMD", |
||
23 | "ADA", |
||
24 | "PDA", |
||
25 | "PDD_to_AMD", |
||
26 | "AMD_finite", |
||
27 | "PDD_finite", |
||
28 | "PDD_reconstructable", |
||
29 | "AMD_estimate", |
||
30 | ] |
||
31 | |||
32 | |||
def PDD(
    pset: PeriodicSet,
    k: int,
    lexsort: bool = True,
    collapse: bool = True,
    collapse_tol: float = 1e-4,
    return_row_data: bool = False,
) -> Union[FloatArray, Tuple[FloatArray, list]]:
    """Compute the pointwise distance distribution (PDD) of a periodic
    set (usually representing a crystal).

    The PDD is a geometry based descriptor that does not depend on the
    choice of motif and unit cell. It is a matrix in which every row
    belongs to a point of the motif: the first entry is the weight of
    the row and the remaining entries are the distances from that point
    to its k nearest neighbors.

    Parameters
    ----------
    pset : :class:`amd.PeriodicSet <.periodicset.PeriodicSet>`
        A periodic set (crystal).
    k : int
        Number of neighbors considered for each point in a unit cell.
        The result has k + 1 columns, the first one holding weights.
    lexsort : bool, default True
        Order the rows lexicographically.
    collapse: bool, default True
        Merge duplicate rows (rows within ``collapse_tol`` of each other
        in the Chebyshev metric).
    collapse_tol: float, default 1e-4
        Rows closer than ``collapse_tol`` in the Chebyshev metric are
        merged, and the weight of a merged row reflects how many rows
        it replaces.
    return_row_data: bool, default False
        If True, return a tuple ``(pdd, groups)``. ``groups[i]`` lists
        the indices (positions in ``pset.asym_unit``) of the points
        represented by the row ``pdd[i]``.

    Returns
    -------
    pdd : :class:`numpy.ndarray`
        The PDD of ``pset``, a :class:`numpy.ndarray` with ``k+1``
        columns. If ``return_row_data`` is True, a tuple
        (:class:`numpy.ndarray`, list) is returned instead.

    Raises
    ------
    ValueError
        If ``pset`` is not a :class:`PeriodicSet`.

    Examples
    --------
    Make list of PDDs with ``k=100`` for crystals in data.cif::

        pdds = []
        for periodic_set in amd.CifReader('data.cif'):
            pdd = amd.PDD(periodic_set, 100)
            pdds.append(pdd)

    Make list of PDDs with ``k=10`` for crystals in these CSD refcode
    families (requires csd-python-api)::

        pdds = []
        for periodic_set in amd.CSDReader(['HXACAN', 'ACSALA'], families=True):
            pdds.append(amd.PDD(periodic_set, 10))

    Manually create a periodic set as a tuple (motif, cell)::

        # simple cubic lattice
        motif = np.array([[0,0,0]])
        cell = np.array([[1,0,0], [0,1,0], [0,0,1]])
        periodic_set = amd.PeriodicSet(motif, cell)
        cubic_pdd = amd.PDD(periodic_set, 100)
    """

    if not isinstance(pset, PeriodicSet):
        raise ValueError(
            f"Expected {PeriodicSet.__name__}, got {pset.__class__.__name__}"
        )

    # The core routine returns the pieces separately; glue them together.
    row_weights, row_dists, row_groups = _PDD(
        pset, k, lexsort=lexsort, collapse=collapse, collapse_tol=collapse_tol
    )
    n_rows = len(row_dists)
    matrix = np.empty(shape=(n_rows, k + 1), dtype=np.float64)
    matrix[:, 0] = row_weights  # column 0: weights
    matrix[:, 1:] = row_dists   # columns 1..k: neighbor distances

    return (matrix, row_groups) if return_row_data else matrix
||
121 | |||
122 | |||
def _PDD(
    pset: PeriodicSet,
    k: int,
    lexsort: bool = True,
    collapse: bool = True,
    collapse_tol: float = 1e-4,
) -> Tuple[FloatArray, FloatArray, list[list[int]]]:
    """Core of :func:`PDD`; see PDD() for the parameter documentation.

    Always returns a tuple ``(weights, dists, groups)``. ``weights`` and
    ``dists`` are merged into one matrix by PDD(), and ``groups`` (for
    each output row, the indices into ``pset.asym_unit`` it represents)
    is optionally returned by PDD().

    For disordered structures one set of rows is computed per disorder
    configuration and all rows are stacked, with the weights
    renormalised to sum to 1. If the number of configurations exceeds
    ``MAX_DISORDER_CONFIGS`` a warning is issued and only the
    configuration taking the highest-occupancy group of each assembly
    is used.
    """

    asym_unit = pset.motif[pset.asym_unit]
    weights = pset.multiplicities / pset.motif.shape[0]

    # Disordered structures
    # Maps the kept asym. unit index of a substitutionally disordered
    # site to the indices that were masked in its favour.
    subs_disorder_info = {}

    if pset.disorder:
        # Gather which disorder assemblies must be considered
        _asym_mask = np.full((asym_unit.shape[0],), fill_value=True)
        asm_sizes = {}  # assembly index -> number of groups to choose from
        for i, asm in enumerate(pset.disorder):
            grps = asm.groups

            # Ignore assemblies with 1 group
            if len(grps) < 2:
                continue

            if asm.is_substitutional:
                # For substitutional disorder, mask all but one atom
                mask_inds = [grps[j].indices[0] for j in range(1, len(grps))]
                keep = grps[0].indices[0]
                subs_disorder_info[keep] = mask_inds
                _asym_mask[mask_inds] = False
            else:
                asm_sizes[i] = len(grps)

        asm_sizes_arr = np.array(list(asm_sizes.values()))
        if _array_product_exceeds(asm_sizes_arr, MAX_DISORDER_CONFIGS):
            warnings.warn(
                f"Disorder configs exceeds limit "
                f"amd.globals_.MAX_DISORDER_CONFIGS={MAX_DISORDER_CONFIGS}, "
                "defaulting to majority occupancy config"
            )
            # Build the single fallback config over the SAME assemblies
            # (and in the same order) as asm_sizes, because the masking
            # loop below reads config_inds[i] with i enumerating
            # asm_sizes. Iterating over every assembly in pset.disorder
            # would misalign the entries whenever an assembly was
            # skipped or handled as substitutional above.
            majority_config = []
            for asm_ind in asm_sizes:
                best, _ = max(
                    enumerate(pset.disorder[asm_ind].groups),
                    key=lambda g: g[1].occupancy,
                )
                majority_config.append(best)
            configs = [majority_config]
        else:
            configs = itertools.product(*(range(t) for t in asm_sizes.values()))

        # One PDD for each disorder configuration
        dists_list, inds_list = [], []
        for config_inds in configs:

            # Mask groups not selected
            asym_mask = _asym_mask.copy()
            motif_mask = np.full((pset.motif.shape[0],), fill_value=True)
            for i, asm_ind in enumerate(asm_sizes.keys()):
                for j, grp in enumerate(pset.disorder[asm_ind].groups):
                    if j != config_inds[i]:
                        for t in grp.indices:
                            asym_mask[t] = False
                            # NOTE(review): assumes the motif points
                            # generated by asym. unit point t occupy the
                            # contiguous slice starting at
                            # pset.asym_unit[t] -- confirm against
                            # PeriodicSet.
                            m_i = pset.asym_unit[t]
                            mul = pset.multiplicities[t]
                            motif_mask[m_i : m_i + mul] = False

            # Ask for k + 1 neighbors and drop the first column
            # (presumably the zero distance of each point to itself).
            dists = nearest_neighbors(
                pset.motif[motif_mask], pset.cell, asym_unit[asym_mask], k + 1
            )
            dists_list.append(dists[:, 1:])
            inds_list.append(np.where(asym_mask)[0])

        dists = np.vstack(dists_list)
        inds = list(np.concatenate(inds_list))
        weights = np.concatenate([weights[i] for i in inds_list])
        weights /= np.sum(weights)

    else:
        dists = nearest_neighbors(pset.motif, pset.cell, asym_unit, k + 1)
        dists = dists[:, 1:]
        inds = list(range(len(dists)))

    # Collapse rows within tolerance
    groups = None
    if collapse:
        weights, dists, group_labs = _merge_pdd_rows(weights, dists, collapse_tol)
        # Fewer rows than labels means at least one merge happened.
        if dists.shape[0] != len(group_labs):
            groups = [[] for _ in range(weights.shape[0])]
            for old_ind, new_ind in enumerate(group_labs):
                groups[new_ind].append(int(inds[old_ind]))

    if groups is None:
        groups = [[int(i)] for i in inds]

    # Add back substitutionally disordered sites to group info
    if subs_disorder_info:
        for i, masked_inds in subs_disorder_info.items():
            for grp in groups:
                if i in grp:
                    grp.extend(masked_inds)

    if lexsort:
        # lexsort treats its LAST key as primary, so reverse the columns
        # to sort by the first distance column first.
        lex_ordering = np.lexsort(dists.T[::-1])
        weights = weights[lex_ordering]
        dists = dists[lex_ordering]
        groups = [groups[i] for i in lex_ordering]

    return weights, dists, groups
||
234 | |||
235 | |||
def AMD(pset: PeriodicSet, k: int) -> FloatArray:
    """Compute the average minimum distance (AMD) of a periodic set
    (usually representing a crystal).

    The AMD is the centroid (weighted average of the rows) of the PDD
    (pointwise distance distribution) and hence is also independent of
    the choice of motif and unit cell. It is a vector of average
    distances from points to their k neighbouring points.

    Parameters
    ----------
    pset : :class:`amd.PeriodicSet <.periodicset.PeriodicSet>`
        A periodic set (crystal).
    k : int
        Number of neighbors considered for each point in a unit cell.

    Returns
    -------
    :class:`numpy.ndarray`
        The AMD of ``pset``, a :class:`numpy.ndarray` shape ``(k, )``.

    Examples
    --------
    Make list of AMDs with k = 100 for crystals in data.cif::

        amds = []
        for periodic_set in amd.CifReader('data.cif'):
            amds.append(amd.AMD(periodic_set, 100))

    Make list of AMDs with k = 10 for crystals in these CSD refcode families::

        amds = []
        for periodic_set in amd.CSDReader(['HXACAN', 'ACSALA'], families=True):
            amds.append(amd.AMD(periodic_set, 10))

    Manually create a periodic set as a tuple (motif, cell)::

        # simple cubic lattice
        motif = np.array([[0,0,0]])
        cell = np.array([[1,0,0], [0,1,0], [0,0,1]])
        periodic_set = amd.PeriodicSet(motif, cell)
        cubic_amd = amd.AMD(periodic_set, 100)
    """
    # Sorting and collapsing rows cannot change a weighted average, so
    # both are switched off.
    row_weights, row_dists, _ = _PDD(pset, k, lexsort=False, collapse=False)
    return np.average(row_dists, axis=0, weights=row_weights)
||
281 | |||
282 | |||
@numba.njit(cache=True, fastmath=True)
def PDD_to_AMD(pdd: FloatArray) -> FloatArray:
    """Derive an AMD from an existing PDD, which is faster than
    computing both from scratch.

    Parameters
    ----------
    pdd : :class:`numpy.ndarray`
        The PDD of a periodic set as given by :class:`PDD() <.PDD>`;
        column 0 is read as the row weights.

    Returns
    -------
    :class:`numpy.ndarray`
        The AMD of the periodic set, i.e. for every distance column the
        sum over rows of ``weight * distance``, so that
        ``amd.PDD_to_AMD(amd.PDD(pset, k))`` corresponds to
        ``amd.AMD(pset, k)``.
    """

    n_rows = pdd.shape[0]
    n_dist_cols = pdd.shape[-1] - 1  # every column except the weights
    result = np.empty((n_dist_cols,), dtype=np.float64)
    for c in range(n_dist_cols):
        acc = 0
        for r in range(n_rows):
            acc += pdd[r, 0] * pdd[r, c + 1]
        result[c] = acc
    return result
||
306 | |||
307 | |||
def AMD_finite(motif: FloatArray) -> FloatArray:
    """Compute the AMD of a finite m-point set up to k = m - 1.

    Parameters
    ----------
    motif : :class:`numpy.ndarray`
        Collection of points, one point per row.

    Returns
    -------
    :class:`numpy.ndarray`
        The AMD of ``motif``, a vector shape ``(motif.shape[0] - 1, )``.

    Examples
    --------
    The (L-infinity) AMD distance between finite trapezium and kite
    point sets, which have the same list of inter-point distances::

        trapezium = np.array([[0,0],[1,1],[3,1],[4,0]])
        kite      = np.array([[0,0],[1,1],[1,-1],[4,0]])

        trap_amd = amd.AMD_finite(trapezium)
        kite_amd = amd.AMD_finite(kite)

        l_inf_dist = np.amax(np.abs(trap_amd - kite_amd))
    """

    # Full symmetric matrix of pairwise distances.
    pairwise = squareform(pdist(motif))
    # After sorting each row the first entry is the zero self-distance;
    # drop it to keep the m - 1 distances to the other points.
    neighbour_dists = np.sort(pairwise, axis=-1)[:, 1:]
    return np.average(neighbour_dists, axis=0)
||
337 | |||
338 | |||
def PDD_finite(
    motif: FloatArray,
    lexsort: bool = True,
    collapse: bool = True,
    collapse_tol: float = 1e-4,
    return_row_data: bool = False,
) -> Union[FloatArray, Tuple[FloatArray, list]]:
    """Compute the PDD of a finite m-point set up to k = m - 1.

    Parameters
    ----------
    motif : :class:`numpy.ndarray`
        Collection of points, one point per row.
    lexsort : bool, default True
        Order the rows lexicographically.
    collapse: bool, default True
        Merge duplicate rows (rows within ``collapse_tol`` of each other
        in the Chebyshev metric).
    collapse_tol: float, default 1e-4
        Rows closer than ``collapse_tol`` in the Chebyshev metric are
        merged: their weights are summed and their distances averaged.
    return_row_data: bool, default False
        If True, return a tuple ``(pdd, groups)`` where ``groups[i]``
        contains indices of points in ``motif`` corresponding to
        ``pdd[i]``.

    Returns
    -------
    pdd : :class:`numpy.ndarray`
        The PDD of ``motif``, a :class:`numpy.ndarray` with ``m``
        columns (one weight column followed by ``m - 1`` distances). If
        ``return_row_data`` is True, a tuple
        (:class:`numpy.ndarray`, list) is returned instead.

    Examples
    --------
    The PDD distance between finite trapezium and kite point sets, which
    have the same list of inter-point distances::

        trapezium = np.array([[0,0],[1,1],[3,1],[4,0]])
        kite      = np.array([[0,0],[1,1],[1,-1],[4,0]])

        trap_pdd = amd.PDD_finite(trapezium)
        kite_pdd = amd.PDD_finite(kite)

        dist = amd.EMD(trap_pdd, kite_pdd)
    """

    n_points = motif.shape[0]
    # Sorted distances from every point to all others (self-distance
    # in the first sorted position is dropped).
    row_dists = np.sort(squareform(pdist(motif)), axis=-1)[:, 1:]
    row_weights = np.full((n_points,), 1 / n_points)
    groups = [[i] for i in range(len(row_dists))]

    # TODO: use _merge_pdd_rows
    if collapse:
        close_pairs = pdist(row_dists, metric="chebyshev") <= collapse_tol
        if close_pairs.any():
            groups = _collapse_into_groups(close_pairs)
            row_weights = np.array([np.sum(row_weights[g]) for g in groups])
            row_dists = np.array(
                [np.average(row_dists[g], axis=0) for g in groups],
                dtype=np.float64,
            )

    if lexsort:
        # lexsort's last key is the primary one, so hand it the distance
        # columns in reverse order to sort by the first column first.
        order = np.lexsort(row_dists.T[::-1])
        row_weights = row_weights[order]
        row_dists = row_dists[order]
        if return_row_data:
            groups = [groups[i] for i in order]

    pdd = np.empty(shape=(len(row_weights), n_points), dtype=np.float64)
    pdd[:, 0] = row_weights
    pdd[:, 1:] = row_dists

    return (pdd, groups) if return_row_data else pdd
||
417 | |||
418 | |||
def PDD_reconstructable(pset: PeriodicSet, lexsort: bool = True) -> FloatArray:
    """Compute the PDD of a periodic set with ``k`` (number of columns)
    large enough that the periodic set can be reconstructed from the
    PDD with :func:`amd.reconstruct.reconstruct`. Does NOT return
    weights or collapse rows.

    Parameters
    ----------
    pset : :class:`amd.PeriodicSet <.periodicset.PeriodicSet>`
        A periodic set (crystal).
    lexsort : bool, default True
        Order the rows lexicographically.

    Returns
    -------
    pdd : :class:`numpy.ndarray`
        The PDD of ``pset`` with enough columns to reconstruct ``pset``
        using :func:`amd.reconstruct.reconstruct`.

    Raises
    ------
    ValueError
        If ``pset`` is not a :class:`PeriodicSet`, or if its dimension
        is neither 2 nor 3.
    """

    if not isinstance(pset, PeriodicSet):
        raise ValueError(
            f"Expected {PeriodicSet.__name__}, got {pset.__class__.__name__}"
        )

    if pset.ndim not in (2, 3):
        raise ValueError(
            "Reconstructing from PDD is only possible for 2 and 3 dimensions."
        )

    # Threshold passed to the neighbor search: twice the cell diameter.
    threshold = diameter(pset.cell) * 2
    rows, _, _ = nearest_neighbors_minval(pset.motif, pset.cell, threshold)

    if not lexsort:
        return rows
    # Reverse the columns so that the first column is the primary key.
    return rows[np.lexsort(rows.T[::-1])]
||
454 | |||
455 | |||
def AMD_estimate(pset: PeriodicSet, k: int) -> FloatArray:
    r"""Compute an estimate of :class:`AMD <.AMD>` based on the
    :class:`PPC <.periodicset.PeriodicSet.PPC>` of ``pset``.

    Parameters
    ----------
    pset : :class:`amd.PeriodicSet <.periodicset.PeriodicSet>`
        A periodic set (crystal).
    k : int
        Length of the estimate; values are produced for 1, ..., k
        neighbors.

    Returns
    -------
    amd_est : :class:`numpy.ndarray`
        An array shape (k, ), where ``amd_est[i]``
        :math:`= \text{PPC} \sqrt[n]{i + 1}` in n dimensions, whose
        ratio with AMD has been shown to converge to 1.

    Raises
    ------
    ValueError
        If ``pset`` is not a :class:`PeriodicSet`.
    """

    if not isinstance(pset, PeriodicSet):
        raise ValueError(
            f"Expected {PeriodicSet.__name__}, got {pset.__class__.__name__}"
        )

    neighbour_counts = np.arange(1, k + 1, dtype=np.float64)
    point_packing_coeff = pset.PPC()
    # n-th root of each neighbor count, n being the dimension of pset.
    nth_roots = np.power(neighbour_counts, 1.0 / pset.ndim)
    return point_packing_coeff * nth_roots
||
479 | |||
480 | |||
def PDA(
    pset: PeriodicSet,
    k: int,
    lexsort: bool = True,
    collapse: bool = True,
    collapse_tol: float = 1e-4,
    return_row_data: bool = False,
) -> Union[FloatArray, Tuple[FloatArray, list]]:
    """Compute the pointwise deviation from asymptotic distribution,
    essentially a normalisation of the pointwise distance distribution
    of ``pset``. The PDA records how far the distances in the PDD are
    from what the asymptotic estimate predicts.

    The PDD of ``pset`` is a geometry based descriptor independent of
    choice of motif and unit cell. Its asymptotic behaviour is well
    understood and depends on the point density of the periodic set.
    The PDA is the difference between the PDD and its asymptotic curve
    (the weight column is left unchanged).

    Parameters
    ----------
    pset : :class:`amd.PeriodicSet <.periodicset.PeriodicSet>`
        A periodic set (crystal).
    k : int
        Number of neighbors considered for each point in a unit cell.
        The result has k + 1 columns, the first one holding weights.
    lexsort : bool, default True
        Order the rows lexicographically.
    collapse: bool, default True
        Merge duplicate rows (rows within ``collapse_tol`` of each other
        in the Chebyshev metric).
    collapse_tol: float, default 1e-4
        Rows closer than ``collapse_tol`` in the Chebyshev metric are
        merged, and the weight of a merged row reflects how many rows
        it replaces.
    return_row_data: bool, default False
        If True, return a tuple ``(pda, groups)``; ``groups`` is the row
        data returned by :func:`PDD` for the same arguments, describing
        which points each row of ``pda`` corresponds to.

    Returns
    -------
    pda : :class:`numpy.ndarray`
        The PDA of ``pset``, a :class:`numpy.ndarray` with ``k+1``
        columns. If ``return_row_data`` is True, a tuple
        (:class:`numpy.ndarray`, list) is returned instead.
    """
    # Always request the row data; it is discarded below if unwanted.
    deviations, row_groups = PDD(
        pset,
        k,
        lexsort=lexsort,
        collapse=collapse,
        collapse_tol=collapse_tol,
        return_row_data=True,
    )
    # Subtract the asymptotic curve from the distance columns only.
    deviations[:, 1:] -= AMD_estimate(pset, k)

    return (deviations, row_groups) if return_row_data else deviations
||
544 | |||
545 | |||
def ADA(pset: PeriodicSet, k: int) -> FloatArray:
    """Compute the average deviation from asymptotic, essentially a
    normalisation of the average minimum distance of ``pset``. The ADA
    records how far the distances in the AMD are from what the
    asymptotic estimate predicts.

    The AMD of ``pset`` is a geometry based descriptor independent of
    choice of motif and unit cell. Its asymptotic behaviour is well
    understood and depends on the point density of the periodic set.
    The ADA is the difference between the AMD and its asymptotic curve.

    Parameters
    ----------
    pset : :class:`amd.PeriodicSet <.periodicset.PeriodicSet>`
        A periodic set (crystal).
    k : int
        Number of neighbors considered for each point in a unit cell.

    Returns
    -------
    :class:`numpy.ndarray`
        The ADA of ``pset``, a :class:`numpy.ndarray` shape ``(k, )``.
    """
    average_dists = AMD(pset, k)
    asymptotic_curve = AMD_estimate(pset, k)
    return average_dists - asymptotic_curve
||
570 | |||
571 | |||
@numba.njit(cache=True, fastmath=True)
def _array_product_exceeds(values, limit):
    """Return True if the product of ``values`` exceeds ``limit``,
    otherwise False (including when ``values`` is empty).

    The running product is checked after every factor, so the function
    returns as soon as the limit is passed instead of multiplying out
    the whole array.
    """
    tot = 1
    for i in range(len(values)):
        tot *= values[i]
        if tot > limit:
            return True
    return False
||
581 | |||
582 | |||
@numba.njit(cache=True, fastmath=True)
def _merge_pdd_rows(weights, dists, collapse_tol):
    """Collapse the weights & rows of a PDD, and return an array of
    group labels (new indices of old rows).

    Rows are processed in order. Each row not yet assigned starts a new
    group, and every later unassigned row whose entries all lie within
    ``collapse_tol`` of that first row joins the group. A merged row is
    the plain mean of its members' distances and carries the sum of
    their weights. If nothing is merged the inputs are returned as-is.
    """

    n_rows, n_cols = dists.shape
    labels = np.empty((n_rows,), dtype=np.int64)
    # absorbed[r] is True once row r has joined an earlier row's group.
    absorbed = np.zeros((n_rows,), dtype=np.bool_)
    n_groups = 0

    for anchor in range(n_rows):
        if absorbed[anchor]:
            continue

        labels[anchor] = n_groups

        for other in range(anchor + 1, n_rows):
            if absorbed[other]:
                continue

            # Chebyshev test: every column must be within tolerance.
            within_tol = True
            for col in range(n_cols):
                gap = np.abs(dists[anchor, col] - dists[other, col])
                if gap > collapse_tol:
                    within_tol = False
                    break

            if within_tol:
                labels[other] = n_groups
                absorbed[other] = True

        n_groups += 1

    # Every row ended up alone: nothing to merge.
    if n_groups == n_rows:
        return weights, dists, labels

    merged_weights = np.zeros((n_groups,), dtype=np.float64)
    merged_dists = np.zeros((n_groups, n_cols), dtype=np.float64)
    members = np.zeros((n_groups,), dtype=np.int64)

    for r in range(n_rows):
        g = labels[r]
        merged_weights[g] += weights[r]
        merged_dists[g] += dists[r]
        members[g] += 1

    # Turn the per-group sums of distances into means.
    for g in range(n_groups):
        merged_dists[g] /= members[g]

    return merged_weights, merged_dists, labels
||
633 | |||
634 | |||
635 | def _collapse_into_groups(overlapping: npt.NDArray[np.bool_]) -> list: |
||
636 | """Return a list of groups of indices where all indices in the same |
||
637 | group overlap. ``overlapping`` indicates for each pair of items in a |
||
638 | set whether or not the items overlap, in the shape of a condensed |
||
639 | distance matrix. |
||
640 | """ |
||
641 | |||
642 | overlapping = squareform(overlapping) |
||
643 | group_nums = {} |
||
644 | group = 0 |
||
645 | for i, row in enumerate(overlapping): |
||
646 | if i not in group_nums: |
||
647 | group_nums[i] = group |
||
648 | group += 1 |
||
649 | for j in np.argwhere(row).T[0]: |
||
650 | if j not in group_nums: |
||
651 | group_nums[j] = group_nums[i] |
||
652 | |||
653 | groups = collections.defaultdict(list) |
||
654 | for row_ind, group_num in sorted(group_nums.items()): |
||
655 | groups[group_num].append(row_ind) |
||
656 | |||
657 | return list(groups.values()) |
||
658 |