Passed
Push — master ( 7c2d4e...5cae11 )
by Daniel
06:09
created

nearest_neighbours_minval()   B

Complexity

Conditions 6

Size

Total Lines 48
Code Lines 31

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 31
dl 0
loc 48
rs 8.2026
c 0
b 0
f 0
cc 6
nop 3
1
"""Implements core function nearest_neighbours used for AMD and PDD
2
calculations.
3
"""
4
5
from typing import Tuple, Iterable
6
from itertools import product
7
8
import numba
9
import numpy as np
10
import numpy.typing as npt
11
from scipy.spatial import KDTree
12
from scipy.spatial.distance import cdist
13
14
15
def nearest_neighbours(
        motif: npt.NDArray,
        cell: npt.NDArray,
        x: npt.NDArray,
        k: int
) -> Tuple[npt.NDArray[np.float64], ...]:
    """Find the ``k`` nearest neighbours in a periodic set for points in
    ``x``.

    Given a periodic set described by ``motif`` and ``cell``, a query
    set of points ``x`` and an integer ``k``, find the ``k`` nearest
    neighbours in the periodic set for all points in ``x``. Return
    distances to neighbours in order, the point cloud generated during
    the search and the indices of which points in the cloud are the
    neighbours of points in ``x``.

    Parameters
    ----------
    motif : :class:`numpy.ndarray`
        Cartesian coordinates of the motif, shape (no points, dims).
    cell : :class:`numpy.ndarray`
        The unit cell as a square array, shape (dims, dims).
    x : :class:`numpy.ndarray`
        Array of points to query for neighbours. For AMD/PDD invariants
        this is the motif, or more commonly an asymmetric unit of it.
    k : int
        Number of nearest neighbours to find for each point in ``x``.

    Returns
    -------
    dists : numpy.ndarray
        Array shape ``(x.shape[0], k)`` of distances from points in
        ``x`` to their ``k`` nearest neighbours in the periodic set, in
        order. E.g. ``dists[m][n]`` is the distance from ``x[m]`` to its
        n-th nearest neighbour in the periodic set.
    cloud : numpy.ndarray
        Collection of points in the periodic set that was generated
        during the nearest neighbour search.
    inds : numpy.ndarray
        Array shape ``(x.shape[0], k)`` containing the indices of
        nearest neighbours in ``cloud``. E.g. the n-th nearest neighbour
        to ``x[m]`` is ``cloud[inds[m][n]]``.

    Raises
    ------
    ValueError
        If ``motif`` contains no points (the search could never collect
        ``k`` points and would not terminate).
    """

    if len(motif) == 0:
        raise ValueError('motif must contain at least one point')

    def _query_cloud(points):
        """Query ``points`` for the k nearest neighbours of ``x``, always
        returning 2D arrays of shape ``(x.shape[0], k)``.
        """
        found_dists, found_inds = KDTree(
            points, leafsize=30, compact_nodes=False, balanced_tree=False
        ).query(x, k=k)
        # KDTree.query squeezes the neighbour axis when k == 1; restore it
        # so the column indexing below and the documented shapes hold.
        if k == 1:
            found_dists = found_dists[:, np.newaxis]
            found_inds = found_inds[:, np.newaxis]
        return found_dists, found_inds

    # Generate a cloud of lattice points such that lattice + motif has k points
    if cell.shape[0] == 3:
        int_lat_generator = _generate_integer_lattice_3D()
    else:
        int_lat_generator = _generate_integer_lattice(cell.shape[0])

    n_points = 0
    int_lat_cloud = []
    while n_points <= k:
        layer = next(int_lat_generator)
        n_points += layer.shape[0] * len(motif)
        int_lat_cloud.append(layer)

    # Add one layer to the lattice, on average this is faster
    int_lat_cloud.append(next(int_lat_generator))
    cloud = _int_lattice_to_cloud(motif, cell, np.concatenate(int_lat_cloud))

    # Find k neighbours in the point cloud for points in x
    dists_, inds = _query_cloud(cloud)

    # Generate layers of lattice points until they are too far away to give
    # nearer neighbours than have already been found. For a lattice point l,
    # points in l + motif are further away from x than |l| - max|p-p'| (where
    # p in x, p' in motif), used to check if l is too far away.
    motif_diameter = np.amax(cdist(x, motif))
    lattice_layers = []
    while True:
        lattice = _close_lattice_points(
            next(int_lat_generator), cell, dists_[:, -1], motif_diameter
        )
        if lattice.size == 0:
            break
        lattice_layers.append(lattice)

    if lattice_layers:
        lattice_layers = np.concatenate(lattice_layers)
        # Only the points already found as neighbours can remain neighbours,
        # so the rest of the old cloud is dropped before re-querying.
        cloud = np.vstack((
            cloud[np.unique(inds)], _lattice_to_cloud(motif, lattice_layers)
        ))
        dists_, inds = _query_cloud(cloud)

    return dists_, cloud, inds
105
106
107
def _generate_integer_lattice(dims: int) -> Iterable[npt.NDArray[np.float64]]:
108
    """Generate batches of integer lattice points. Each yield gives all
109
    points (that have not already been yielded) inside a sphere centered
110
    at the origin with radius d; d starts at 0 and increments by 1 on
111
    each loop.
112
113
    Parameters
114
    ----------
115
    dims : int
116
        The dimension of Euclidean space the lattice is in.
117
118
    Yields
119
    -------
120
    :class:`numpy.ndarray`
121
        Yields arrays of integer points in `dims`-dimensional Euclidean
122
        space.
123
    """
124
125
    d = 0
126
127
    if dims == 1:
128
        yield np.array([[0]], dtype=np.float64)
129
        while True:
130
            d += 1
131
            yield np.array([[-d], [d]], dtype=np.float64)
132
133
    ymax = {}
134
    while True:
135
        positive_int_lattice = []
136
        while True:
137
            batch = False
138
            for xy in product(range(d + 1), repeat=dims-1):
139
                if xy not in ymax:
140
                    ymax[xy] = 0
141
                if _in_sphere(xy, ymax[xy], d):
142
                    positive_int_lattice.append((*xy, ymax[xy]))
143
                    batch = True
144
                    ymax[xy] += 1
145
            if not batch:
146
                break
147
        yield _reflect_positive_integer_lattice(np.array(positive_int_lattice))
148
        d += 1
149
150
151
@numba.njit(cache=True)
def _generate_integer_lattice_3D() -> Iterable[npt.NDArray[np.float64]]:
    """Numba-compiled, 3D-only counterpart of
    _generate_integer_lattice(). 3D is the most common case and the
    general function is difficult to express for any dimension with
    numba.

    Yields
    ------
    :class:`numpy.ndarray`
        Arrays of integer points in 3-dimensional Euclidean space.
    """

    radius = 0
    # For each non-negative (x, y), the smallest z not yet yielded.
    next_z = {}
    while True:
        new_points = []
        while True:
            added = False
            for x in range(radius + 1):
                for y in range(radius + 1):
                    column = (x, y)
                    if column not in next_z:
                        next_z[column] = 0
                    if x ** 2 + y ** 2 + next_z[column] ** 2 <= radius ** 2:
                        new_points.append((x, y, next_z[column]))
                        added = True
                        next_z[column] += 1
            # A full sweep that adds nothing means this radius is exhausted.
            if not added:
                break
        yield _reflect_positive_integer_lattice(np.array(new_points))
        radius += 1
183
184
185
@numba.njit(cache=True)
def _reflect_positive_integer_lattice(
        positive_int_lattice: npt.NDArray
) -> npt.NDArray[np.float64]:
    """Extend points with non-negative coordinates to the full lattice
    by reflecting them across every combination of axes, skipping
    reflections that would map a point onto itself.
    """

    dims = positive_int_lattice.shape[-1]
    points = []
    points.extend(positive_int_lattice)

    for n_axes in range(1, dims + 1):

        # First combination of n_axes axes: 0, 1, ..., n_axes - 1.
        axes = np.arange(n_axes)
        points.extend(_reflect_in_axes(positive_int_lattice, axes))

        # Step through the remaining combinations of n_axes axes out of dims
        # in lexicographic order.
        while True:
            pos = n_axes - 1
            # Find the right-most entry that is not yet at its largest value.
            for _ in range(n_axes):
                if axes[pos] != pos + dims - n_axes:
                    break
                pos -= 1
            else:
                # Every entry is at its maximum: no combinations are left.
                break
            axes[pos] += 1
            for j in range(pos + 1, n_axes):
                axes[j] = axes[j-1] + 1
            points.extend(_reflect_in_axes(positive_int_lattice, axes))

    full_lattice = np.empty(shape=(len(points), dims), dtype=np.float64)
    for row in range(len(points)):
        full_lattice[row] = points[row]

    return full_lattice
221
222
223
@numba.njit(cache=True)
def _reflect_in_axes(
        positive_int_lattice: npt.NDArray,
        axes: npt.NDArray
) -> npt.NDArray:
    """Negate the coordinates listed in `axes` for the points of
    `positive_int_lattice`. Points with a zero in any of those
    coordinates are left out, since reflecting them would duplicate a
    point produced by another reflection.
    """

    zeros_on_axes = (positive_int_lattice[:, axes] == 0).sum(axis=-1)
    # Boolean indexing copies, so the input array is not modified below.
    reflected = positive_int_lattice[zeros_on_axes == 0]
    reflected[:, axes] *= -1
    return reflected
236
237
238
@numba.njit(cache=True)
def _close_lattice_points(
        int_lattice: npt.NDArray,
        cell: npt.NDArray,
        max_nn_dists: npt.NDArray,
        max_cdist: float
) -> npt.NDArray[np.float64]:
    """Convert integer lattice points to Cartesian lattice points with
    ``cell`` and keep those close enough to the origin that the
    corresponding motif copy could contain nearest neighbours.

    A lattice point is kept when its norm is strictly less than
    ``max(max_nn_dists) + max_cdist``, where ``max_nn_dists`` holds the
    distances to the k-th nearest neighbours found so far and
    ``max_cdist`` is the max of cdist(x, motif).
    """

    cartesian = int_lattice @ cell
    cutoff = np.amax(max_nn_dists) + max_cdist
    norms = np.sqrt(np.sum(cartesian ** 2, axis=-1))
    return cartesian[norms < cutoff]
255
256
257
@numba.njit(cache=True)
def _lattice_to_cloud(
        motif: npt.NDArray,
        lattice: npt.NDArray
) -> npt.NDArray[np.float64]:
    """Build a cloud of points from a periodic set by translating the
    motif by every point of ``lattice`` (Cartesian lattice points, i.e.
    integer lattice points already multiplied by the cell).

    The result has ``len(motif) * len(lattice)`` rows; rows are grouped
    by lattice point, in the order the lattice points are given.
    """

    n_motif = len(motif)
    cloud = np.empty(
        (n_motif * len(lattice), motif.shape[-1]), dtype=np.float64
    )
    for idx in range(len(lattice)):
        start = idx * n_motif
        cloud[start:start + n_motif] = motif + lattice[idx]
    return cloud
275
276
277
@numba.njit(cache=True)
def _int_lattice_to_cloud(
        motif: npt.NDArray,
        cell: npt.NDArray,
        int_lattice: npt.NDArray
) -> npt.NDArray[np.float64]:
    """Build a cloud of points from a periodic set with the given motif
    and cell, from a batch of integer lattice points (as generated by
    _generate_integer_lattice).
    """

    # Map integer coordinates to Cartesian lattice points first.
    cartesian_lattice = int_lattice @ cell
    return _lattice_to_cloud(motif, cartesian_lattice)
288
289
290
@numba.njit(cache=True)
def _in_sphere(xy: Tuple[float, float], z: float, d: float) -> bool:
    """Return True if the point with leading coordinates ``xy`` and last
    coordinate ``z`` satisfies sum(i^2 for i in xy) + z^2 <= d^2.
    """

    radius_sq = d ** 2
    norm_sq = z ** 2
    for coord in xy:
        norm_sq += coord ** 2
    return norm_sq <= radius_sq
298
299
300
def nearest_neighbours_minval(
        motif: npt.NDArray,
        cell: npt.NDArray,
        min_val: float
) -> Tuple[npt.NDArray, ...]:
    """Return the same ``dists``/PDD matrix as ``nearest_neighbours``,
    but with enough columns such that all values in the last column are
    at least ``min_val``. Unlike ``nearest_neighbours``, does not take a
    query array ``x`` but only finds neighbours to motif points. Used in
    ``PDD_reconstructable``.

    Parameters
    ----------
    motif : :class:`numpy.ndarray`
        Cartesian coordinates of the motif, shape (no points, dims).
    cell : :class:`numpy.ndarray`
        The unit cell as a square array, shape (dims, dims).
    min_val : float
        Value which every entry of the last returned column of distances
        must reach.

    Returns
    -------
    dists : numpy.ndarray
        Distances from each motif point to its nearest neighbours in
        order, excluding the first column of the query result (the
        zero distance from each point to itself) and ending at the
        first column whose values are all at least ``min_val``.
    cloud : numpy.ndarray
        Collection of points in the periodic set that was generated
        during the search.
    inds : numpy.ndarray
        Indices into ``cloud`` from the final query of all cloud points
        for every motif point. Unlike ``dists`` this is not trimmed, so
        it has one column per point in ``cloud``.

    Raises
    ------
    ValueError
        If ``min_val`` is NaN or positive infinity, which no distance
        can reach.
    """

    # NaN and +inf can never be reached by the last column, so the search
    # below would not terminate.
    if not min_val < np.inf:
        raise ValueError('min_val must be a number less than infinity')

    def _query_all(points):
        """Sorted distances/indices from every motif point to all of
        ``points``.
        """
        return KDTree(
            points, leafsize=30, compact_nodes=False, balanced_tree=False
        ).query(motif, k=points.shape[0])

    # Generate initial cloud of points from the periodic set
    int_lat_generator = _generate_integer_lattice(cell.shape[0])
    cloud = np.concatenate([
        _lattice_to_cloud(motif, next(int_lat_generator) @ cell)
        for _ in range(3)
    ])

    # Find neighbours in the point cloud for points in motif
    dists_, inds = _query_all(cloud)
    dists = np.zeros_like(dists_, dtype=np.float64)

    # Add layers & find nearest neighbours until all distances smaller than
    # min_val don't change
    max_cdist = np.amax(cdist(motif, motif))
    while True:
        target_reached = np.all(dists_[:, -1] >= min_val)
        if target_reached:
            col = np.argwhere(np.all(dists_ >= min_val, axis=0))[0][0] + 1
            if np.array_equal(dists[:, :col], dists_[:, :col]):
                break
        dists = dists_
        lattice = next(int_lat_generator) @ cell
        # Lattice points may only be discarded (and the search stopped for
        # lack of close points) once the last column has reached min_val.
        # Before that, every layer is needed to extend the matrix; stopping
        # early left no column >= min_val and the final lookup failed.
        if target_reached:
            # Points in l + motif are at least |l| - max_cdist away from
            # every motif point.
            closest_dist_bound = np.linalg.norm(lattice, axis=-1) - max_cdist
            lattice = lattice[closest_dist_bound <= np.amax(dists_[:, -1])]
            if lattice.size == 0:
                break
        cloud = np.vstack((cloud, _lattice_to_cloud(motif, lattice)))
        dists_, inds = _query_all(cloud)

    # Both exits above happen with the target reached, so a qualifying
    # column always exists here.
    k = np.argwhere(np.all(dists_ >= min_val, axis=0))[0][0]
    return dists_[:, 1:k+1], cloud, inds
348
349
350
def generate_concentric_cloud(motif, cell):
    """Yield successive batches of points from the periodic set given by
    (motif, cell), each batch further from the origin than the last.

    Every batch is the motif translated by one batch of lattice points:
    the integer lattice points produced by
    ``_generate_integer_lattice(cell.shape[0])``, multiplied by ``cell``.

    Parameters
    ----------
    motif : :class:`numpy.ndarray`
        Cartesian representation of the motif, shape (no points, dims).
    cell : :class:`numpy.ndarray`
        Cartesian representation of the unit cell, shape (dims, dims).

    Yields
    ------
    :class:`numpy.ndarray`
        Arrays of points from the periodic set.
    """

    for int_layer in _generate_integer_lattice(cell.shape[0]):
        cartesian_layer = int_layer @ cell
        yield _lattice_to_cloud(motif, cartesian_layer)
374