Passed
Push — master ( 8c12c2...4daa36 )
by Daniel
07:46
created

amd._nearest_neighbours.memoized_generator()   A

Complexity

Conditions 2

Size

Total Lines 10
Code Lines 9

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 9
dl 0
loc 10
rs 9.95
c 0
b 0
f 0
cc 2
nop 1
1
"""Implements core function nearest_neighbours used for AMD and PDD
2
calculations.
3
"""
4
5
from typing import Tuple, Iterable
6
from itertools import product, tee
7
import functools
8
9
import numba
10
import numpy as np
11
from scipy.spatial import KDTree
12
from scipy.spatial.distance import cdist
13
14
# Public API of this module; everything else is an implementation detail.
__all__ = [
    'nearest_neighbours',
    'nearest_neighbours_data',
    'nearest_neighbours_minval',
    'generate_concentric_cloud',
]
20
21
22
def nearest_neighbours(
        motif: np.ndarray, cell: np.ndarray, x: np.ndarray, k: int
) -> np.ndarray:
    """Find distances to ``k`` nearest neighbours in a periodic set for
    each point in ``x``.

    Given a periodic set described by ``motif`` and ``cell``, a query
    set of points ``x`` and an integer ``k``, find distances to the
    ``k`` nearest neighbours in the periodic set for all points in
    ``x``. Returns an array with shape (x.shape[0], k) of distances to
    the neighbours. This function only returns distances, see the
    function nearest_neighbours_data() to also get the point cloud and
    indices of the points which are neighbours.

    Parameters
    ----------
    motif : :class:`numpy.ndarray`
        Cartesian coordinates of the motif, shape (no points, dims).
    cell : :class:`numpy.ndarray`
        The unit cell as a square array, shape (dims, dims).
    x : :class:`numpy.ndarray`
        Array of points to query for neighbours. For AMD/PDD invariants
        this is the motif, or more commonly an asymmetric unit of it.
    k : int
        Number of nearest neighbours to find for each point in ``x``.

    Returns
    -------
    dists : numpy.ndarray
        Array shape ``(x.shape[0], k)`` of distances from points in
        ``x`` to their ``k`` nearest neighbours in the periodic set in
        order, e.g. ``dists[m][n]`` is the distance from ``x[m]`` to its
        n-th nearest neighbour in the periodic set.
    """

    m, dims = motif.shape
    # Get an initial collection of lattice points + a generator for more
    int_lat_cloud, int_lat_generator = _get_integer_lattice(dims, m, k)
    cloud = _int_lattice_to_cloud(motif, cell, int_lat_cloud)

    # Squared distances to k nearest neighbours. The first m columns of
    # the distance matrix belong to the motif copy at the origin, so
    # their maximum is the largest query-point-to-motif distance.
    sqdists = _cdist_sqeuclidean(x, cloud)
    motif_diam = np.sqrt(_max_in_columns(sqdists, m))
    sqdists.partition(k - 1)
    sqdists = sqdists[:, :k]
    sqdists.sort()

    # Generate layers of lattice until they are too far away to give
    # nearer neighbours. For a lattice point l, points in l + motif are
    # further away from x than |l| - max|p-p'| (p in x, p' in motif),
    # giving a bound we can use to rule out distant lattice points
    max_sqd = np.amax(sqdists[:, -1])
    bound = (np.sqrt(max_sqd) + motif_diam) ** 2

    while True:

        # Get next layer of lattice
        lattice = _close_lattice_points(next(int_lat_generator), cell, bound)
        if lattice.size == 0:  # None are close enough
            break

        # Squared distances to new points
        sqdists_ = _cdist_sqeuclidean(x, _lattice_to_cloud(motif, lattice))
        # Columns (new points) closer to some query point than the
        # current worst k-th neighbour distance
        close_cols = np.any(sqdists_ < max_sqd, axis=0)
        if not np.any(close_cols):  # None are close enough
            break

        # Squared distances to up to k nearest new points
        sqdists_ = sqdists_[:, close_cols]
        if sqdists_.shape[-1] > k:
            sqdists_.partition(k - 1)
        sqdists_ = sqdists_[:, :k]
        sqdists_.sort()

        # Merge existing and new distances
        sqdists = _merge_sorted_arrays(sqdists, sqdists_)
        max_sqd = np.amax(sqdists[:, -1])
        bound = (np.sqrt(max_sqd) + motif_diam) ** 2

    return np.sqrt(sqdists)
102
103
104
def nearest_neighbours_data(
        motif: np.ndarray, cell: np.ndarray, x: np.ndarray, k: int
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Find the ``k`` nearest neighbours in a periodic set for each
    point in ``x``.

    Given a periodic set described by ``motif`` and ``cell``, a query
    set of points ``x`` and an integer ``k``, find the ``k`` nearest
    neighbours in the periodic set for all points in ``x``. Return
    an array of distances to neighbours, the point cloud generated
    during the search and the indices of which points in the cloud are
    the neighbours of points in ``x``.

    Parameters
    ----------
    motif : :class:`numpy.ndarray`
        Cartesian coordinates of the motif, shape (no points, dims).
    cell : :class:`numpy.ndarray`
        The unit cell as a square array, shape (dims, dims).
    x : :class:`numpy.ndarray`
        Array of points to query for neighbours. For AMD/PDD invariants
        this is the motif, or more commonly an asymmetric unit of it.
    k : int
        Number of nearest neighbours to find for each point in ``x``.

    Returns
    -------
    dists : numpy.ndarray
        Array shape ``(x.shape[0], k)`` of distances from points in
        ``x`` to their ``k`` nearest neighbours in the periodic set in
        order, e.g. ``dists[m][n]`` is the distance from ``x[m]`` to its
        n-th nearest neighbour in the periodic set.
    cloud : numpy.ndarray
        Collection of points in the periodic set that were generated
        during the search.
    inds : numpy.ndarray
        Array shape ``(x.shape[0], k)`` containing the indices of
        nearest neighbours in ``cloud``, e.g. the n-th nearest neighbour
        to ``x[m]`` is ``cloud[inds[m][n]]``.
    """

    m, dims = motif.shape
    int_lat, int_lat_gen = _get_integer_lattice(dims, m, k)
    cloud = _int_lattice_to_cloud(motif, cell, int_lat)
    dists = cdist(x, cloud)
    # First m columns are distances to the motif copy at the origin
    motif_diam = _max_in_columns(dists, m)
    inds = np.argsort(dists)[:, :k]
    dists = np.take_along_axis(dists, inds, -1)
    # Squared bound on |lattice point| beyond which no motif copy can
    # contain a nearer neighbour
    b = (np.amax(dists[:, -1]) + motif_diam) ** 2

    while True:

        lattice = _close_lattice_points(next(int_lat_gen), cell, b)
        if lattice.size == 0:  # No lattice point in this layer is close
            break

        # Distances are recomputed against the whole cloud so that
        # ``inds`` always indexes into the returned ``cloud``.
        cloud = np.concatenate((cloud, _lattice_to_cloud(motif, lattice)))
        dists = cdist(x, cloud)
        inds = np.argsort(dists)[:, :k]
        dists = np.take_along_axis(dists, inds, -1)
        b = (np.amax(dists[:, -1]) + motif_diam) ** 2

    return dists, cloud, inds
167
168
169
def _get_integer_lattice_cache(f):
170
    """Specialised cache for ``_get_integer_lattice()``."""
171
172
    cache = {}
173
    num_points_cache = {}
174
175
    @functools.wraps(f)
176
    def wrapper(dims, m, k):
177
178
        if dims not in num_points_cache:
0 ignored issues
show
Comprehensibility Best Practice introduced by
The variable num_points_cache does not seem to be defined.
Loading history...
179
            num_points_cache[dims] = []
180
181
        n_points = 0
182
        n_layers = 0
183
        within_cache = False
184
        for num_p in num_points_cache[dims]:
185
            if n_points > k / m:
186
                within_cache = True
187
                break
188
            n_points += num_p
189
            n_layers += 1
190
        n_layers += 1
191
192
        if not (within_cache and (dims, n_layers) in cache):
0 ignored issues
show
Comprehensibility Best Practice introduced by
The variable cache does not seem to be defined.
Loading history...
193
            layers, int_lat_generator = f(dims, m, k)
194
            n_layers = len(layers)
195
            if len(num_points_cache[dims]) < n_layers:
196
                num_points_cache[dims] = [len(i) for i in layers]
197
            layers = np.concatenate(layers)
198
            cache[(dims, n_layers)] = [layers, int_lat_generator]
199
200
        arr, g = cache[(dims, n_layers)]
201
        cache[(dims, n_layers)][1], r = tee(g)
202
        return arr, r
203
204
    return wrapper
205
206
207
@_get_integer_lattice_cache
def _get_integer_lattice(dims, m, k):
    """Return an initial batch of integer lattice points (number
    according to m and k) and a generator for more distant points.

    This undecorated body returns the initial layers as a list; the
    ``_get_integer_lattice_cache`` decorator concatenates them into a
    single array before returning to the caller.

    Parameters
    ----------
    dims : int
        The dimension of Euclidean space the lattice is in.
    m : int
        Number of motif points.
    k : int
        Number of nearest neighbours to find (parameter of
        nearest_neighbours).

    Returns
    -------
    initial_integer_lattice : :class:`numpy.ndarray`
        A collection of integer lattice points. Consists of the first
        few layers generated by ``integer_lattice_generator`` (number of
        layers depends on m, k).
    integer_lattice_generator
        A generator for integer lattice points more distant than those
        in ``initial_integer_lattice``.
    """

    g = iter(_generate_integer_lattice(dims))
    layers = [next(g)]
    n_points = 1  # the first layer is taken to hold a single point
    # Take layers until there are more than k / m lattice points
    while n_points <= k / m:
        layer = next(g)
        n_points += layer.shape[0]
        layers.append(layer)
    # One extra layer beyond the threshold
    layers.append(next(g))
    return layers, g
242
243
244
def memoized_generator(f):
    """Cache the results of a generator function.

    The decorated function is called at most once per distinct tuple of
    positional arguments (which must be hashable). Each call returns an
    independent iterator over the same sequence of items, and items
    already produced are replayed rather than recomputed.
    """
    cache = {}

    @functools.wraps(f)
    def wrapper(*args):
        if args not in cache:
            cache[args] = f(*args)
        # Split the stored iterator: one copy stays in the cache,
        # untouched, and the other is handed to the caller.
        cache[args], r = tee(cache[args])
        return r

    return wrapper
254
255
256
@memoized_generator
def _generate_integer_lattice(dims: int) -> Iterable[np.ndarray]:
    """Generate batches of integer lattice points. Each yield gives all
    points (that have not already been yielded) inside a sphere centered
    at the origin with radius d; d starts at 0 and increments by 1 on
    each loop.

    Parameters
    ----------
    dims : int
        The dimension of Euclidean space the lattice is in.

    Yields
    ------
    :class:`numpy.ndarray`
        Yields arrays of integer points in `dims`-dimensional Euclidean
        space.
    """

    d = 0
    if dims == 1:
        # In one dimension each layer is just the pair -d, d. This loop
        # never ends, so the general code below is not reached.
        yield np.zeros((1, 1), dtype=np.float64)
        while True:
            d += 1
            yield np.array([[-d], [d]], dtype=np.float64)

    # For each tuple xy of the first dims - 1 non-negative coordinates,
    # ymax[xy] is the next value of the last coordinate not yet emitted.
    ymax = {}
    while True:
        positive_int_lattice = []
        # Keep sweeping until a full pass over all xy adds no new point
        # within radius d.
        while True:
            batch = False
            for xy in product(range(d + 1), repeat=dims-1):
                if xy not in ymax:
                    ymax[xy] = 0
                if sum(i**2 for i in xy) + ymax[xy]**2 <= d**2:
                    positive_int_lattice.append((*xy, ymax[xy]))
                    batch = True
                    ymax[xy] += 1
            if not batch:
                break
        pos_int_lat = np.array(positive_int_lattice, dtype=np.float64)
        # Only non-negative coordinates were generated; reflect to
        # obtain the points with negative coordinates as well.
        yield _reflect_positive_integer_lattice(pos_int_lat)
        d += 1
299
300
301
@numba.njit(cache=True, fastmath=True)
def _reflect_positive_integer_lattice(
        positive_int_lattice: np.ndarray
) -> np.ndarray:
    """Reflect points in the positive quadrant across all combinations
    of axes, without duplicating points that are invariant under
    reflections.

    Returns a new 2D float64 array containing the original points
    followed by their reflections.
    """

    dims = positive_int_lattice.shape[-1]
    batches = []
    batches.extend(positive_int_lattice)

    for n_reflections in range(1, dims + 1):

        # First combination of axes to negate: 0, 1, ..., n_reflections-1
        axes = np.arange(n_reflections)
        batches.extend(_reflect_in_axes(positive_int_lattice, axes))

        # Step ``axes`` through every remaining combination of
        # n_reflections distinct axes out of dims, in increasing order.
        while True:
            # Find the rightmost entry that has not reached its maximum
            i = n_reflections - 1
            for _ in range(n_reflections):
                if axes[i] != i + dims - n_reflections:
                    break
                i -= 1
            else:
                # Every entry is at its maximum: combinations exhausted
                break
            axes[i] += 1
            for j in range(i + 1, n_reflections):
                axes[j] = axes[j-1] + 1
            batches.extend(_reflect_in_axes(positive_int_lattice, axes))

    # Pack the list of 1D points into a single 2D array
    int_lattice = np.empty(shape=(len(batches), dims), dtype=np.float64)
    for i in range(len(batches)):
        int_lattice[i] = batches[i]

    return int_lattice
337
338
339
@numba.njit(cache=True, fastmath=True)
def _reflect_in_axes(
        positive_int_lattice: np.ndarray, axes: np.ndarray
) -> np.ndarray:
    """Reflect points in `positive_int_lattice` in the axes described by
    `axes`, without duplicating invariant points.

    A point with a zero coordinate in any of the chosen axes is skipped,
    since negating that coordinate would reproduce a point generated by
    a reflection in fewer axes.
    """
    # Keep only points whose coordinates in all chosen axes are non-zero
    not_on_axes = (positive_int_lattice[:, axes] == 0).sum(axis=-1) == 0
    # Boolean indexing yields a new array, so the input is not modified
    int_lattice = positive_int_lattice[not_on_axes]
    int_lattice[:, axes] *= -1
    return int_lattice
350
351
352
@numba.njit(cache=True, fastmath=True)
def _close_lattice_points(
        int_lattice: np.ndarray, cell: np.ndarray, bound: float
) -> np.ndarray:
    """Given integer lattice points, a unit cell and ``bound``, return
    lattice points which are close enough such that the corresponding
    motif copy could contain nearest neighbours. ``bound`` should be
    equal to (max_d + motif_diam) ** 2, where max_d is the maximum
    k-th nearest neighbour distance found so far and motif_diam is the
    largest distance between any point in the query set and motif.

    The returned points are in Cartesian coordinates (``int_lattice``
    multiplied by ``cell``); the result has zero rows if none qualify.
    """

    lattice = int_lattice @ cell
    inds = []
    for i in range(len(lattice)):
        # Squared distance of the lattice point from the origin
        s = 0.
        for xyz in lattice[i]:
            s += xyz ** 2
        if s < bound:
            inds.append(i)
    ret = np.empty((len(inds), lattice.shape[-1]), dtype=np.float64)
    for i in range(len(inds)):
        ret[i] = lattice[inds[i]]
    return ret
376
377
378
@numba.njit(cache=True, fastmath=True)
def _lattice_to_cloud(motif: np.ndarray, lattice: np.ndarray) -> np.ndarray:
    """Transform a batch of lattice points (generated by
    _generate_integer_lattice then multiplied by the cell) into a cloud
    of points from a periodic set.

    The result has ``len(motif) * len(lattice)`` rows: one translated
    copy of the motif per lattice point, in the order of ``lattice``.
    """

    m = len(motif)
    layer = np.empty((m * len(lattice), motif.shape[-1]), dtype=np.float64)
    i1 = 0
    for translation in lattice:
        # Rows i1:i2 hold the motif copy shifted by this lattice point
        i2 = i1 + m
        layer[i1:i2] = motif + translation
        i1 = i2
    return layer
393
394
395
@numba.njit(cache=True, fastmath=True)
def _int_lattice_to_cloud(
        motif: np.ndarray, cell: np.ndarray, int_lattice: np.ndarray
) -> np.ndarray:
    """Transform a batch of integer lattice points (generated by
    _generate_integer_lattice) into a cloud of points from a periodic
    set.

    The integer points are first converted to Cartesian lattice points
    by multiplying with ``cell``.
    """
    return _lattice_to_cloud(motif, int_lattice @ cell)
404
405
406
@numba.njit(cache=True, fastmath=True)
def _cdist_sqeuclidean(arr1, arr2):
    """Squared Euclidean distance between points in ``arr1`` and
    ``arr2``.

    Returns a float64 array of shape ``(len(arr1), len(arr2))`` whose
    entry ``[i, j]`` is the squared distance between ``arr1[i]`` and
    ``arr2[j]``.
    """
    n1, n2 = arr1.shape[0], arr2.shape[0]
    dims = arr1.shape[-1]
    res = np.empty((n1, n2), dtype=np.float64)
    for i in range(n1):
        for j in range(n2):
            s = 0.
            for n in range(dims):
                s += (arr1[i, n] - arr2[j, n]) ** 2
            res[i, j] = s
    return res
419
420
421
@numba.njit(cache=True, fastmath=True)
def _max_in_column(arr, col):
    """Return maximum value in chosen column col of array arr.

    The running maximum starts at 0, so the result is never negative;
    0 is returned if the column has no positive entries.
    """
    ret = 0
    for i in range(arr.shape[0]):
        v = arr[i, col]
        if v > ret:
            ret = v
    return ret
430
431
432
@numba.njit(cache=True, fastmath=True)
def _max_in_columns(arr, max_col):
    """Return maximum value over the first ``max_col`` columns of array
    arr (columns ``0`` to ``max_col - 1``).

    The running maximum starts at 0, so the result is never negative.
    """
    ret = 0
    for col in range(max_col):
        v = _max_in_column(arr, col)
        if v > ret:
            ret = v
    return ret
442
443
444
@numba.njit(cache=True, fastmath=True)
def _merge_sorted_arrays(dists, dists_):
    """Merge two 2D arrays sorted along last axis into one sorted array
    with same number of columns as ``dists``. Optimised for the distance
    arrays in nearest_neighbours, where ``dists`` will contain most of
    the smallest elements and only a few values in later columns will
    need to be replaced with values in ``dists_``.

    ``dists`` and ``dists_`` must have the same number of rows and
    ``dists_`` must have at least one column. Neither input is modified;
    a new array is returned.
    """

    m, n_new_points = dists_.shape
    n_cols = dists.shape[1]
    ret = np.copy(dists)

    for i in range(m):
        # Traverse row backwards (negative index j) until a value no
        # larger than dists_[i, 0] is found. The scan stops at the start
        # of the row, so a new value smaller than every existing one
        # cannot push the index out of bounds.
        j = 0
        dp_ = 0
        d_ = dists_[i, dp_]
        while j > -n_cols and dists[i, j - 1] > d_:
            j -= 1

        if j == 0:  # If dists_[i, 0] >= dists[i, -1], no need to insert
            continue

        # dp points to dists[i], dp_ points to dists_[i].
        # fill ret with the smaller dist, then increment pointers and
        # repeat. The first slot always takes a value from dists_, so dp
        # never advances past the last column of dists.
        dp = j
        d = dists[i, dp]

        while j < 0:
            if d <= d_:
                ret[i, j] = d
                dp += 1
                d = dists[i, dp]
            else:
                ret[i, j] = d_
                dp_ += 1
                if dp_ < n_new_points:
                    d_ = dists_[i, dp_]
                else:  # ran out of points in dists_
                    d_ = np.inf
            j += 1

    return ret
490
491
492
def nearest_neighbours_minval(
        motif: np.ndarray, cell: np.ndarray, min_val: float
) -> Tuple[np.ndarray, ...]:
    """Return the same ``dists``/PDD matrix as ``nearest_neighbours``,
    but with enough columns such that all values in the last column are
    at least ``min_val``. Unlike ``nearest_neighbours``, does not take a
    query array ``x`` but only finds neighbours to motif points. Used in
    ``PDD_reconstructable``.

    TODO: this function should be updated in line with
    nearest_neighbours.

    Parameters
    ----------
    motif : :class:`numpy.ndarray`
        Cartesian coordinates of the motif, shape (no points, dims).
    cell : :class:`numpy.ndarray`
        The unit cell as a square array, shape (dims, dims).
    min_val : float
        Value that every entry of the last returned column must reach.

    Returns
    -------
    dists : numpy.ndarray
        Distances from each motif point to its nearest neighbours, with
        the first column of the query result (the distance from each
        point to itself) dropped.
    cloud : numpy.ndarray
        Points of the periodic set generated during the search.
    inds : numpy.ndarray
        Indices into ``cloud`` from the final query, sorted by distance.
        This array is not trimmed: it has one column per point in
        ``cloud`` and includes the leading self-match column.
    """

    # Generate initial cloud of points from the periodic set
    int_lat_generator = _generate_integer_lattice(cell.shape[0])
    int_lat_generator = iter(int_lat_generator)
    cloud = []
    for _ in range(3):
        cloud.append(_lattice_to_cloud(motif, next(int_lat_generator) @ cell))
    cloud = np.concatenate(cloud)

    # Find k neighbours in the point cloud for points in motif
    dists_, inds = KDTree(
        cloud, leafsize=30, compact_nodes=False, balanced_tree=False
    ).query(motif, k=cloud.shape[0])
    # Previous iteration's distances; zeros so the first comparison
    # below cannot succeed by accident
    dists = np.zeros_like(dists_, dtype=np.float64)

    # Add layers & find k nearest neighbours until all distances smaller than
    # min_val don't change
    max_cdist = np.amax(cdist(motif, motif))
    while True:
        if np.all(dists_[:, -1] >= min_val):
            # First column in which every distance reaches min_val
            col = np.argwhere(np.all(dists_ >= min_val, axis=0))[0][0] + 1
            if np.array_equal(dists[:, :col], dists_[:, :col]):
                break
        dists = dists_
        lattice = next(int_lat_generator) @ cell
        # Lower bound on the distance from any motif point to the motif
        # copy at each new lattice point
        closest_dist_bound = np.linalg.norm(lattice, axis=-1) - max_cdist
        is_close = closest_dist_bound <= np.amax(dists_[:, -1])
        if not np.any(is_close):
            break
        cloud = np.vstack((cloud, _lattice_to_cloud(motif, lattice[is_close])))
        dists_, inds = KDTree(
            cloud, leafsize=30, compact_nodes=False, balanced_tree=False
        ).query(motif, k=cloud.shape[0])

    k = np.argwhere(np.all(dists >= min_val, axis=0))[0][0]
    return dists_[:, 1:k+1], cloud, inds
541
542
543
def generate_concentric_cloud(motif, cell):
    """Generates batches of points from a periodic set given by (motif,
    cell) which get successively further away from the origin.

    Each yield gives all points (that have not already been yielded)
    which lie in a unit cell whose corner lattice point was generated by
    ``_generate_integer_lattice(cell.shape[0])``. The generator never
    terminates on its own; the caller decides when to stop.

    Parameters
    ----------
    motif : :class:`numpy.ndarray`
        Cartesian representation of the motif, shape (no points, dims).
    cell : :class:`numpy.ndarray`
        Cartesian representation of the unit cell, shape (dims, dims).

    Yields
    ------
    :class:`numpy.ndarray`
        Yields arrays of points from the periodic set.
    """

    int_lat_generator = _generate_integer_lattice(cell.shape[0])
    for layer in int_lat_generator:
        # Convert integer lattice points to Cartesian translations and
        # place one copy of the motif at each
        yield _lattice_to_cloud(motif, layer @ cell)
567