Passed
Push — master ( c02a6e...9e2dce )
by Daniel
03:49
created

amd._nearest_neighbours._close_lattice_points()   A

Complexity

Conditions 1

Size

Total Lines 17
Code Lines 9

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 9
dl 0
loc 17
rs 9.95
c 0
b 0
f 0
cc 1
nop 4
1
"""Implements core function nearest_neighbours used for AMD and PDD
2
calculations.
3
"""
4
5
from typing import Tuple, Iterable
6
from itertools import product
7
8
import numba
9
import numpy as np
10
import numpy.typing as npt
11
from scipy.spatial import KDTree
12
from scipy.spatial.distance import cdist
13
14
15
def nearest_neighbours(
        motif: npt.NDArray,
        cell: npt.NDArray,
        x: npt.NDArray,
        k: int
) -> Tuple[npt.NDArray[np.float64], ...]:
    """Find the ``k`` nearest neighbours in a periodic set for points in
    ``x``.

    Given a periodic set described by ``motif`` and ``cell``, a query
    set of points ``x`` and an integer ``k``, find the ``k`` nearest
    neighbours in the periodic set for all points in ``x``. Return
    distances to neighbours in order, the point cloud generated during
    the search and the indices of which points in the cloud are the
    neighbours of points in ``x``.

    Parameters
    ----------
    motif : :class:`numpy.ndarray`
        Cartesian coordinates of the motif, shape (no points, dims).
    cell : :class:`numpy.ndarray`
        The unit cell as a square array, shape (dims, dims).
    x : :class:`numpy.ndarray`
        Array of points to query for neighbours, shape (no points, dims).
        For AMD/PDD invariants this is the motif, or more commonly an
        asymmetric unit of it.
    k : int
        Number of nearest neighbours to find for each point in ``x``.

    Returns
    -------
    dists : numpy.ndarray
        Array shape ``(x.shape[0], k)`` of distances from points in
        ``x`` to their ``k`` nearest neighbours in the periodic set, in
        order. E.g. ``dists[m][n]`` is the distance from ``x[m]`` to its
        n-th nearest neighbour in the periodic set. The shape is 2D also
        when ``k == 1``.
    cloud : numpy.ndarray
        Collection of points in the periodic set that was generated
        during the nearest neighbour search.
    inds : numpy.ndarray
        Array shape ``(x.shape[0], k)`` containing the indices of
        nearest neighbours in ``cloud``. E.g. the n-th nearest neighbour
        to ``x[m]`` is ``cloud[inds[m][n]]``.
    """

    # Generate an initial cloud containing more than k points.
    int_lat_generator = _generate_integer_lattice(cell.shape[0])
    n_points = 0
    int_lat_cloud = []
    while n_points <= k:
        layer = next(int_lat_generator)
        n_points += layer.shape[0] * len(motif)
        int_lat_cloud.append(layer)

    # Add one more layer from the lattice generator; on average this is
    # faster than leaving the extra work to the refinement loop below.
    int_lat_cloud.append(next(int_lat_generator))
    cloud = _int_lattice_to_cloud(motif, cell, np.concatenate(int_lat_cloud))

    # Find k neighbours for points in x.
    dists_, inds = _query_k_nearest(cloud, x, k)

    # Generate more layers of lattice points until they are too far away
    # to contain nearer neighbours than have already been found. For a
    # lattice point l, points in l + motif are further away from x than
    # |l| - max|p - p'| (p in x, p' in motif); this is used to check
    # whether l is too far away.
    # NOTE(review): layers are concentric in integer coordinates, not in
    # Cartesian ones, so stopping at the first layer with no close points
    # assumes a reasonably reduced cell -- confirm for very skewed cells.
    max_cdist = np.amax(cdist(x, motif))
    lattice_layers = []
    while True:
        lattice = _close_lattice_points(
            next(int_lat_generator), cell, dists_[:, -1], max_cdist
        )
        if lattice.size == 0:
            break
        lattice_layers.append(lattice)

    if lattice_layers:
        # Adding points can only shrink the k-NN sets, so cloud points that
        # are not currently a neighbour of any x can be dropped. The
        # returned inds index into this rebuilt cloud.
        lattice_layers = np.concatenate(lattice_layers)
        cloud = np.vstack((
            cloud[np.unique(inds)], _lattice_to_cloud(motif, lattice_layers)
        ))
        dists_, inds = _query_k_nearest(cloud, x, k)

    return dists_, cloud, inds


def _query_k_nearest(
        cloud: npt.NDArray,
        x: npt.NDArray,
        k: int
) -> Tuple[npt.NDArray[np.float64], npt.NDArray]:
    """Return distances and indices of the ``k`` nearest points in
    ``cloud`` for each point in ``x``, always as arrays of shape
    ``(len(x), k)``.
    """
    dists, inds = KDTree(
        cloud, leafsize=30, compact_nodes=False, balanced_tree=False
    ).query(x, k=k)
    # scipy squeezes the last axis when k == 1; restore it so callers can
    # always take the k-th neighbour column with [:, -1].
    return dists.reshape(-1, k), inds.reshape(-1, k)
101
102
103
104
105
106
def _generate_integer_lattice(dims: int) -> Iterable[npt.NDArray[np.float64]]:
107
    """Generate batches of integer lattice points. Each yield gives all
108
    points (that have not already been yielded) inside a sphere centered
109
    at the origin with radius d. d starts at 0 and increments by 1 on
110
    each loop.
111
112
    Parameters
113
    ----------
114
    dims : int
115
        The dimension of Euclidean space the lattice is in.
116
117
    Yields
118
    -------
119
    :class:`numpy.ndarray`
120
        Yields arrays of integer points in dims dimensional Euclidean
121
        space.
122
    """
123
124
    d = 0
125
126
    if dims == 1:
127
        yield np.array([[0]], dtype=np.float64)
128
        while True:
129
            d += 1
130
            yield np.array([[-d], [d]], dtype=np.float64)
131
132
    ymax = {}
133
    while True:
134
        positive_int_lattice = []
135
        while True:
136
            batch = []
137
            for xy in product(range(d + 1), repeat=dims-1):
138
                if xy not in ymax:
139
                    ymax[xy] = 0
140
                if _in_sphere(xy, ymax[xy], d):
141
                    batch.append((*xy, ymax[xy]))
142
                    ymax[xy] += 1
143
            if not batch:
144
                break
145
            positive_int_lattice.extend(batch)
146
147
        yield _reflect_positive_integer_lattice(np.array(positive_int_lattice))
148
        d += 1
149
150
151
@numba.njit(cache=True)
def _reflect_positive_integer_lattice(
        positive_int_lattice: npt.NDArray
) -> npt.NDArray[np.float64]:
    """Mirror points of the non-negative orthant into every orthant.

    The output (float64, one point per row) starts with the input points
    unchanged. These are followed by their reflections in each non-empty
    subset of axes, ordered by subset size and lexicographically within
    a size. Points with a zero coordinate on a reflected axis are skipped
    for that subset so that no point is produced twice.
    """

    n_dims = positive_int_lattice.shape[-1]
    rows = []
    rows.extend(positive_int_lattice)

    for r in range(1, n_dims + 1):
        # combo walks through all r-element subsets of range(n_dims) in
        # lexicographic order, e.g. (0, 1), (0, 2), (1, 2) for r=2, n=3.
        combo = np.arange(r)
        rows.extend(_reflect_in_axes(positive_int_lattice, combo))
        while True:
            # Rightmost position that has not reached its maximum value.
            pos = r - 1
            while pos >= 0 and combo[pos] == pos + n_dims - r:
                pos -= 1
            if pos < 0:
                break
            combo[pos] += 1
            for j in range(pos + 1, r):
                combo[j] = combo[j - 1] + 1
            rows.extend(_reflect_in_axes(positive_int_lattice, combo))

    out = np.empty(shape=(len(rows), n_dims), dtype=np.float64)
    for idx, row in enumerate(rows):
        out[idx] = row
    return out
187
188
189
@numba.njit(cache=True)
def _reflect_in_axes(
        positive_int_lattice: npt.NDArray,
        axes: npt.NDArray
) -> npt.NDArray:
    """Negate the coordinates along ``axes`` for every point of
    ``positive_int_lattice`` that is nonzero on all of those axes.

    Points with a zero coordinate on any axis in ``axes`` are left out.
    Reflecting them would duplicate a point produced by reflecting in a
    smaller set of axes.
    """
    zeros_per_point = (positive_int_lattice[:, axes] == 0).sum(axis=-1)
    reflected = positive_int_lattice[zeros_per_point == 0]
    reflected[:, axes] *= -1
    return reflected
202
203
204
@numba.njit(cache=True)
def _close_lattice_points(
        int_lattice: npt.NDArray,
        cell: npt.NDArray,
        max_nn_dists: npt.NDArray,
        max_cdist: float
) -> npt.NDArray[np.float64]:
    """Convert ``int_lattice`` to Cartesian lattice points via ``cell``.
    Keep only those whose motif copy could still hold a nearer neighbour.

    ``max_cdist`` is the largest distance between a query point and a
    motif point, and ``max_nn_dists`` holds the current k-th neighbour
    distances. A lattice point ``l`` is kept when
    ``|l| < max(max_nn_dists) + max_cdist``.
    """
    cartesian = int_lattice @ cell
    # Points of l + motif lie at least |l| - max_cdist from every query
    # point, so lattice points at or beyond this radius cannot help.
    cutoff = np.amax(max_nn_dists) + max_cdist
    lengths = np.sqrt(np.sum(cartesian ** 2, axis=-1))
    return cartesian[lengths < cutoff]
221
222
223
@numba.njit(cache=True)
def _lattice_to_cloud(
        motif: npt.NDArray,
        lattice: npt.NDArray
) -> npt.NDArray[np.float64]:
    """Place a copy of ``motif`` at every Cartesian lattice point in
    ``lattice``, e.g. points from ``_generate_integer_lattice`` that
    have been multiplied by the cell.

    Rows of the result are grouped by lattice point, in the order of
    ``lattice``, with the motif order kept inside each group.
    """
    n_motif = len(motif)
    out = np.empty((n_motif * len(lattice), motif.shape[-1]), dtype=np.float64)
    for t in range(len(lattice)):
        start = t * n_motif
        out[start:start + n_motif] = motif + lattice[t]
    return out
241
242
243
@numba.njit(cache=True)
def _int_lattice_to_cloud(
        motif: npt.NDArray,
        cell: npt.NDArray,
        int_lattice: npt.NDArray
) -> npt.NDArray[np.float64]:
    """Map integer lattice points (as yielded by
    ``_generate_integer_lattice``) through ``cell``. Return the copies of
    ``motif`` placed at the resulting Cartesian lattice points.
    """
    lattice = int_lattice @ cell
    return _lattice_to_cloud(motif, lattice)
254
255
256
@numba.njit(cache=True)
def _in_sphere(xy: Tuple[float, ...], z: float, d: float) -> bool:
    """Return True if ``sum(i**2 for i in xy) + z**2 <= d**2``.

    In other words, test whether the point ``(*xy, z)`` lies in the closed
    ball of radius ``d`` about the origin. ``xy`` may have any length; in
    ``_generate_integer_lattice`` it holds the first ``dims - 1``
    coordinates of a point.
    """
    sq_norm = z ** 2
    for coord in xy:
        sq_norm += coord ** 2
    return sq_norm <= d ** 2
263
264
265
def nearest_neighbours_minval(
        motif: npt.NDArray,
        cell: npt.NDArray,
        min_val: float
) -> Tuple[npt.NDArray, ...]:
    """Return the same ``dists``/PDD matrix as ``nearest_neighbours``,
    but with enough columns that all values in the last column are at
    least ``min_val``.

    Unlike ``nearest_neighbours``, this does not take a query array
    ``x``; it only finds neighbours of the motif points. Used in
    ``PDD_reconstructable``.

    Parameters
    ----------
    motif : :class:`numpy.ndarray`
        Cartesian coordinates of the motif, shape (no points, dims).
    cell : :class:`numpy.ndarray`
        The unit cell as a square array, shape (dims, dims).
    min_val : float
        Every entry of the last returned column is at least this value.

    Returns
    -------
    dists : numpy.ndarray
        Sorted distances from each motif point to its neighbours in the
        periodic set. Column 0 of the full query (the distance from each
        motif point to its own copy in ``cloud``) is dropped. The last
        column is the first one whose entries are all at least
        ``min_val``.
    cloud : numpy.ndarray
        Points of the periodic set generated during the search.
    inds : numpy.ndarray
        For each motif point, indices of *all* points in ``cloud`` sorted
        by distance. Unlike ``dists``, this is not truncated.
    """

    max_cdist = np.amax(cdist(motif, motif))

    # Generate an initial cloud from the first three lattice layers (the
    # first layer is the origin, so the motif itself is in the cloud).
    int_lat_generator = _generate_integer_lattice(cell.shape[0])
    cloud = []
    for _ in range(3):
        cloud.append(_lattice_to_cloud(motif, next(int_lat_generator) @ cell))
    cloud = np.concatenate(cloud)

    # Query every point of the cloud, so each row is a full sorted list.
    dists_, inds = KDTree(
        cloud, leafsize=30, compact_nodes=False, balanced_tree=False
    ).query(motif, k=cloud.shape[0])
    dists = np.zeros_like(dists_, dtype=np.float64)

    # Add layers until the columns up to the first one with all values
    # >= min_val stop changing, or no point of the next layer can be close.
    while True:
        if np.all(dists_[:, -1] >= min_val):
            col = np.argwhere(np.all(dists_ >= min_val, axis=0))[0][0] + 1
            if np.array_equal(dists[:, :col], dists_[:, :col]):
                break
        dists = dists_
        lattice = next(int_lat_generator) @ cell
        # Lower bound on the distance from any motif point to l + motif.
        closest_dist_bound = np.linalg.norm(lattice, axis=-1) - max_cdist
        is_close = closest_dist_bound <= np.amax(dists_[:, -1])
        if not np.any(is_close):
            break
        cloud = np.vstack((cloud, _lattice_to_cloud(motif, lattice[is_close])))
        dists_, inds = KDTree(
            cloud, leafsize=30, compact_nodes=False, balanced_tree=False
        ).query(motif, k=cloud.shape[0])

    # On both exit paths, dists and dists_ agree on every column up to the
    # first one with all values >= min_val, so k indexes either of them.
    k = np.argwhere(np.all(dists >= min_val, axis=0))[0][0]
    return dists_[:, 1:k+1], cloud, inds
311
312
313
def generate_concentric_cloud(
        motif: npt.NDArray,
        cell: npt.NDArray
) -> Iterable[npt.NDArray[np.float64]]:
    """Generate batches of points from the periodic set given by
    (``motif``, ``cell``) which get successively further away from the
    origin.

    Each yield gives all points (that have not already been yielded)
    which lie in a unit cell whose corner lattice point was generated by
    ``_generate_integer_lattice(cell.shape[0])``. Batches are ordered by
    the radius of their integer lattice layer, not strictly by Cartesian
    distance.

    Parameters
    ----------
    motif : :class:`numpy.ndarray`
        Cartesian representation of the motif, shape (no points, dims).
    cell : :class:`numpy.ndarray`
        Cartesian representation of the unit cell, shape (dims, dims).

    Yields
    -------
    :class:`numpy.ndarray`
        Arrays of points from the periodic set, shape
        ``(len(motif) * n_lattice_points, dims)`` per batch.
    """

    int_lat_generator = _generate_integer_lattice(cell.shape[0])
    for layer in int_lat_generator:
        yield _int_lattice_to_cloud(motif, cell, layer)
338