Test Failed
Push — master ( 37d7fb...c02a6e )
by Daniel
07:38
created

amd._nearest_neighbours._close_lattice_points()   A

Complexity

Conditions 1

Size

Total Lines 17
Code Lines 9

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 9
dl 0
loc 17
rs 9.95
c 0
b 0
f 0
cc 1
nop 4
1
"""Implements core function nearest_neighbours used for AMD and PDD
2
calculations.
3
"""
4
5
from typing import Tuple, Iterable
6
from itertools import product
7
8
import numba
9
import numpy as np
10
import numpy.typing as npt
11
from scipy.spatial import KDTree
12
from scipy.spatial.distance import cdist
13
14
15
def nearest_neighbours(
        motif: npt.NDArray,
        cell: npt.NDArray,
        x: npt.NDArray,
        k: int
) -> Tuple[npt.NDArray[np.float64], ...]:
    """Find the ``k`` nearest neighbours in a periodic set for points in
    ``x``.

    Given a periodic set described by ``motif`` and ``cell``, a query
    set of points ``x`` and an integer ``k``, find the ``k`` nearest
    neighbours in the periodic set for all points in ``x``. Return
    distances to neighbours in order, the point cloud generated during
    the search and the indices of which points in the cloud are the
    neighbours of points in ``x``.

    Parameters
    ----------
    motif : :class:`numpy.ndarray`
        Cartesian coordinates of the motif, shape (no points, dims).
    cell : :class:`numpy.ndarray`
        The unit cell as a square array, shape (dims, dims).
    x : :class:`numpy.ndarray`
        Array of points to query for neighbours. For AMD/PDD invariants
        this is the motif, or more commonly an asymmetric unit of it.
    k : int
        Number of nearest neighbours to find for each point in ``x``.

    Returns
    -------
    dists : numpy.ndarray
        Array shape ``(x.shape[0], k)`` of distances from points in
        ``x`` to their ``k`` nearest neighbours in the periodic set, in
        order. E.g. ``dists[m][n]`` is the distance from ``x[m]`` to its
        n-th nearest neighbour in the periodic set. The shape holds for
        ``k == 1`` too.
    cloud : numpy.ndarray
        Collection of points in the periodic set that was generated
        during the nearest neighbour search.
    inds : numpy.ndarray
        Array shape ``(x.shape[0], k)`` containing the indices of
        nearest neighbours in ``cloud``. E.g. the n-th nearest neighbour
        to ``x[m]`` is ``cloud[inds[m][n]]``.

    Notes
    -----
    Layer ``s`` yielded by ``_generate_integer_lattice`` holds integer
    points ``n`` with ``s - 1 < |n| <= s``. Also ``|n @ cell| >=
    sigma_min * |n|``, where ``sigma_min`` is the smallest singular
    value of ``cell``. So once ``sigma_min * (s - 1)`` reaches the
    distance bound, no later layer can contain a closer neighbour.
    Stopping at the first layer without close points is not enough for
    skewed cells, because close lattice vectors can have large integer
    coordinates there.
    """

    # Generate an initial cloud of more than k points, so the first
    # query always finds k finite distances.
    int_lat_generator = _generate_integer_lattice(cell.shape[0])
    n_points = 0
    int_lat_cloud = []
    while n_points <= k:
        layer = next(int_lat_generator)
        n_points += layer.shape[0] * len(motif)
        int_lat_cloud.append(layer)

    # Add one layer from the lattice generator, on average this is faster
    int_lat_cloud.append(next(int_lat_generator))
    cloud = _int_lattice_to_cloud(motif, cell, np.concatenate(int_lat_cloud))

    # Find k neighbours for points in x. The cloud is a subset of the
    # periodic set, so these distances are upper bounds on the true ones.
    dists, inds = _query_k_neighbours(cloud, x, k)

    # Generate more layers of lattice points until they are too large to
    # contain nearer neighbours than have already been found. For a lattice
    # point l, points in l + motif are further away from x than
    # |l| - max|p-p'| (p in x, p' in motif). This is used to check whether
    # l is too far away.
    max_cdist = np.amax(cdist(x, motif))
    bound = np.amax(dists[:, -1]) + max_cdist
    min_stretch = _min_lattice_stretch(cell)
    shell = len(int_lat_cloud)  # index of the next layer the generator yields
    lattice_layers = []
    while True:
        lattice = _close_lattice_points(
            next(int_lat_generator), cell, dists[:, -1], max_cdist
        )
        if lattice.size:
            lattice_layers.append(lattice)
        elif min_stretch == 0 or min_stretch * (shell - 1) >= bound:
            # Every point of this and later layers has norm greater than
            # min_stretch * (shell - 1) >= bound. For a degenerate cell no
            # such guarantee exists, so keep the empty-layer rule.
            break
        shell += 1

    if lattice_layers:
        lattice_layers = np.concatenate(lattice_layers)
        # Cloud points that are not current neighbours of any query point
        # cannot become neighbours after closer points are added.
        cloud = np.vstack((
            cloud[np.unique(inds)], _lattice_to_cloud(motif, lattice_layers)
        ))
        dists, inds = _query_k_neighbours(cloud, x, k)

    return dists, cloud, inds


def _query_k_neighbours(
        cloud: npt.NDArray,
        x: npt.NDArray,
        k: int
) -> Tuple[npt.NDArray, npt.NDArray]:
    """Query the ``k`` nearest points of ``cloud`` for each point in
    ``x``. Outputs always keep a trailing axis of length ``k``, because
    scipy squeezes that axis when ``k == 1``.
    """
    dists, inds = KDTree(
        cloud, leafsize=30, compact_nodes=False, balanced_tree=False
    ).query(x, k=k)
    out_shape = np.shape(x)[:-1] + (k,)
    return np.reshape(dists, out_shape), np.reshape(inds, out_shape)


def _min_lattice_stretch(cell: npt.NDArray) -> float:
    """Return the smallest singular value of ``cell``, so that
    ``|n @ cell| >= value * |n|`` for every vector ``n``. Returns 0.0
    when ``cell`` is numerically singular (tolerance as in
    ``numpy.linalg.matrix_rank``), where no positive bound exists.
    """
    sing_vals = np.linalg.svd(cell, compute_uv=False)
    tol = sing_vals[0] * max(cell.shape) * np.finfo(np.float64).eps
    smallest = float(sing_vals[-1])
    return smallest if smallest > tol else 0.0
101
102
103
def _generate_integer_lattice(dims: int) -> Iterable[npt.NDArray[np.float64]]:
    """Yield the integer lattice shell by shell, moving outwards from the
    origin.

    The ``d``-th yielded array (``d = 0, 1, 2, ...``) holds every integer
    point of ``dims``-dimensional space lying in the closed sphere of
    radius ``d`` about the origin that was not part of an earlier yield.

    Parameters
    ----------
    dims : int
        The dimension of Euclidean space the lattice is in.

    Yields
    -------
    :class:`numpy.ndarray`
        Float arrays of integer points, one row per point, ``dims``
        columns.
    """

    if dims == 1:
        # In one dimension every shell after the origin is just {-d, d}.
        yield np.array([[0]], dtype=np.float64)
        radius = 1
        while True:
            yield np.array([[-radius], [radius]], dtype=np.float64)
            radius += 1

    # next_height[xy] is the smallest last coordinate not yet emitted for
    # the column of points whose leading coordinates are xy.
    next_height = {}
    radius = 0
    while True:
        shell = []
        # Each pass raises every column by at most one point. Stop once a
        # whole pass adds nothing inside the current sphere.
        while True:
            added = []
            for xy in product(range(radius + 1), repeat=dims-1):
                z = next_height.setdefault(xy, 0)
                if _in_sphere(xy, z, radius):
                    added.append((*xy, z))
                    next_height[xy] = z + 1
            if not added:
                break
            shell += added

        # Only the non-negative orthant was built; mirror it everywhere.
        yield _reflect_positive_integer_lattice(np.array(shell))
        radius += 1
147
148
@numba.njit(cache=True)
def _reflect_positive_integer_lattice(
        positive_int_lattice: npt.NDArray
) -> npt.NDArray[np.float64]:
    """Reflect points in the positive quadrant across all combinations
    of axes, without duplicating points that are invariant under
    reflections.

    Parameters
    ----------
    positive_int_lattice : :class:`numpy.ndarray`
        Integer points with non-negative coordinates, one point per row
        (shape (no points, dims)).

    Returns
    -------
    :class:`numpy.ndarray`
        A float64 array with ``dims`` columns. It holds the input points,
        followed by their reflections in every non-empty subset of axes.
        A subset is grouped by size, then taken in lexicographic order.
    """

    dims = positive_int_lattice.shape[-1]
    batches = []
    # Start with the unreflected points themselves.
    batches.extend(positive_int_lattice)

    for n_reflections in range(1, dims + 1):

        # Enumerate every combination of n_reflections axes in
        # lexicographic order, starting from (0, 1, ..., n_reflections-1).
        axes = np.arange(n_reflections)
        batches.extend(_reflect_in_axes(positive_int_lattice, axes))

        while True:
            # Find the rightmost position that has not reached its largest
            # allowed value, i + dims - n_reflections.
            i = n_reflections - 1
            for _ in range(n_reflections):
                if axes[i] != i + dims - n_reflections:
                    break
                i -= 1
            else:
                # Every position is at its maximum: last combination done.
                break
            # Advance that position and reset the following positions to
            # the smallest increasing run after it.
            axes[i] += 1
            for j in range(i + 1, n_reflections):
                axes[j] = axes[j-1] + 1
            batches.extend(_reflect_in_axes(positive_int_lattice, axes))

    # Copy the collected rows into a single float64 array.
    int_lattice = np.empty(shape=(len(batches), dims), dtype=np.float64)
    for i in range(len(batches)):
        int_lattice[i] = batches[i]

    return int_lattice
184
185
186
@numba.njit(cache=True)
def _reflect_in_axes(
        positive_int_lattice: npt.NDArray,
        axes: npt.NDArray
) -> npt.NDArray:
    """Negate the coordinates listed in ``axes`` for each point of
    ``positive_int_lattice`` whose coordinates on those axes are all
    nonzero.

    A point with a zero on one of the axes would coincide with its
    reflection in a smaller set of axes, so such points are left out.
    Boolean-mask indexing copies the selected rows, so the input array is
    not modified.
    """

    zeros_on_axes = (positive_int_lattice[:, axes] == 0).sum(axis=-1)
    reflected = positive_int_lattice[zeros_on_axes == 0]
    reflected[:, axes] *= -1
    return reflected
199
200
201
@numba.njit(cache=True)
def _close_lattice_points(
        int_lattice: npt.NDArray,
        cell: npt.NDArray,
        max_nn_dists: npt.NDArray,
        max_cdist: float
) -> npt.NDArray[np.float64]:
    """Keep only the lattice points close enough to the origin that the
    motif copy they translate to could contain nearest neighbours.

    The integer points ``int_lattice`` are mapped to Cartesian
    coordinates with ``cell``. A point ``l`` is kept when
    ``|l| < max(max_nn_dists) + max_cdist``. Here ``max_nn_dists`` are
    the distances to the k-th nearest neighbours found so far, and
    ``max_cdist`` is the maximum of cdist(x, motif).
    """

    translations = int_lattice @ cell
    cutoff = np.amax(max_nn_dists) + max_cdist
    norms = np.sqrt(np.sum(translations ** 2, axis=-1))
    return translations[norms < cutoff]
218
219
220
@numba.njit(cache=True)
def _lattice_to_cloud(
        motif: npt.NDArray,
        lattice: npt.NDArray
) -> npt.NDArray[np.float64]:
    """Translate a copy of ``motif`` by every row of ``lattice`` and stack
    the copies into one point cloud.

    ``lattice`` holds Cartesian lattice vectors, e.g. points from
    _generate_integer_lattice multiplied by the cell. The copy made for
    ``lattice[t]`` occupies rows ``t * len(motif)`` up to
    ``(t + 1) * len(motif)`` of the result.
    """

    n_motif = motif.shape[0]
    n_dims = motif.shape[-1]
    cloud = np.empty((lattice.shape[0] * n_motif, n_dims), dtype=np.float64)
    for t in range(lattice.shape[0]):
        start = t * n_motif
        cloud[start:start + n_motif] = motif + lattice[t]
    return cloud
238
239
240
@numba.njit(cache=True)
def _int_lattice_to_cloud(
        motif: npt.NDArray,
        cell: npt.NDArray,
        int_lattice: npt.NDArray
) -> npt.NDArray[np.float64]:
    """Turn integer lattice points (as yielded by
    _generate_integer_lattice) into points of the periodic set. The
    points are mapped to Cartesian lattice vectors with ``cell``, and
    ``motif`` is placed at each of them.
    """
    translations = int_lattice @ cell
    return _lattice_to_cloud(motif, translations)
251
252
253
@numba.njit(cache=True)
def _in_sphere(xy: Tuple[float, ...], z: float, d: float) -> bool:
    """Return True if sum(i^2 for i in xy) + z^2 <= d^2.

    That is, whether the point ``(*xy, z)`` lies in the closed sphere of
    radius ``d`` about the origin. ``xy`` may have any length;
    _generate_integer_lattice passes ``dims - 1`` leading coordinates.
    """
    s = z ** 2
    for val in xy:
        s += val ** 2
    return s <= d ** 2
260
261
262
def nearest_neighbours_minval(
        motif: npt.NDArray,
        cell: npt.NDArray,
        min_val: float
) -> Tuple[npt.NDArray[np.float64], ...]:
    """Return the same ``dists``/PDD matrix as ``nearest_neighbours``,
    but with enough columns such that all values in the last column are
    at least ``min_val``. Unlike ``nearest_neighbours``, does not take a
    query array ``x`` but only finds neighbours to motif points. Used in
    ``PDD_reconstructable``.

    Parameters
    ----------
    motif : :class:`numpy.ndarray`
        Cartesian coordinates of the motif, shape (no points, dims).
    cell : :class:`numpy.ndarray`
        The unit cell as a square array, shape (dims, dims).
    min_val : float
        Every value in the last returned column is at least this.

    Returns
    -------
    dists : numpy.ndarray
        Distances from each motif point to its nearest neighbours, in
        order. The first column (each motif point's zero distance to
        itself) is removed. It has the fewest columns for which every
        value in the last column is at least ``min_val``.
    cloud : numpy.ndarray
        Points of the periodic set generated during the search.
    inds : numpy.ndarray
        Indices into ``cloud`` of every cloud point, ordered by distance,
        for each motif point. Unlike ``dists`` it is not truncated and
        still includes the first column.
    """

    max_cdist = np.amax(cdist(motif, motif))
    # Initial cloud: the first three layers yielded by the lattice
    # generator (the origin layer plus two more).
    int_lat_generator = _generate_integer_lattice(cell.shape[0])

    cloud = []
    for _ in range(3):
        cloud.append(_lattice_to_cloud(motif, next(int_lat_generator) @ cell))
    cloud = np.concatenate(cloud)

    # Query every cloud point, so each row lists all distances in order.
    dists_, inds = KDTree(
        cloud, leafsize=30, compact_nodes=False, balanced_tree=False
    ).query(motif, k=cloud.shape[0])
    dists = np.zeros_like(dists_, dtype=np.float64)

    # Add layers and requery until the leading columns stop changing.
    while True:
        # Rows are sorted, so some column is entirely >= min_val exactly
        # when the last column is.
        reaches_min_val = np.all(dists_[:, -1] >= min_val)
        if reaches_min_val:
            col = np.argwhere(np.all(dists_ >= min_val, axis=0))[0][0] + 1
            if np.array_equal(dists[:, :col], dists_[:, :col]):
                break
        dists = dists_
        lattice = next(int_lat_generator) @ cell
        closest_dist_bound = np.linalg.norm(lattice, axis=-1) - max_cdist
        is_close = closest_dist_bound <= np.amax(dists_[:, -1])
        if not np.any(is_close):
            if reaches_min_val:
                break
            # No column reaches min_val yet, so farther points are still
            # needed. Breaking here would leave no qualifying column for
            # the lookup below, so take the whole layer.
            is_close[:] = True
        cloud = np.vstack((cloud, _lattice_to_cloud(motif, lattice[is_close])))
        dists_, inds = KDTree(
            cloud, leafsize=30, compact_nodes=False, balanced_tree=False
        ).query(motif, k=cloud.shape[0])

    # First column whose values are all >= min_val; column 0 is dropped.
    k = np.argwhere(np.all(dists_ >= min_val, axis=0))[0][0]
    return dists_[:, 1:k+1], cloud, inds
308