Passed
Push — master ( f996d7...a617b6 )
by Daniel
03:57
created

amd._nearest_neighbours._close_lattice_points()   A

Complexity

Conditions 5

Size

Total Lines 25
Code Lines 16

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 16
dl 0
loc 25
rs 9.1333
c 0
b 0
f 0
cc 5
nop 3
1
"""Implements core function nearest_neighbours used for AMD and PDD
calculations.
"""
4
5
from typing import Tuple, Iterable, Callable
6
from itertools import product, tee, accumulate
7
import functools
8
9
import numba
10
import numpy as np
11
from scipy.spatial import KDTree
12
from scipy.spatial.distance import cdist
13
14
# Names exported by ``from <module> import *``; everything else in this
# module is a private helper.
__all__ = [
    'nearest_neighbours',
    'nearest_neighbours_data',
    'nearest_neighbours_minval',
    'generate_concentric_cloud',
]
20
21
22
def nearest_neighbours(
        motif: np.ndarray, cell: np.ndarray, x: np.ndarray, k: int
) -> np.ndarray:
    """Return distances from each point in ``x`` to its ``k`` nearest
    neighbours in the periodic set described by ``motif`` and ``cell``.

    Only distances are returned; use nearest_neighbours_data() to also
    obtain the generated point cloud and the indices of the neighbours
    within it.

    Parameters
    ----------
    motif : :class:`numpy.ndarray`
        Cartesian coordinates of the motif, shape (no points, dims).
    cell : :class:`numpy.ndarray`
        The unit cell as a square array, shape (dims, dims).
    x : :class:`numpy.ndarray`
        Array of points to query for neighbours. For AMD/PDD invariants
        this is the motif, or more commonly an asymmetric unit of it.
    k : int
        Number of nearest neighbours to find for each point in ``x``.

    Returns
    -------
    dists : numpy.ndarray
        Array shape ``(x.shape[0], k)``; row ``i`` holds the distances
        from ``x[i]`` to its ``k`` nearest neighbours in ascending
        order, so ``dists[i][j]`` is the distance from ``x[i]`` to its
        j-th nearest neighbour in the periodic set.
    """

    n_motif, dims = motif.shape

    # Seed cloud: enough whole unit cells to hold more than k points,
    # plus an iterator over the more distant integer lattice layers.
    seed_lattice, further_layers = _integer_lattice_batches(
        dims, k / n_motif
    )
    seed_cloud = _int_lattice_to_cloud(motif, cell, seed_lattice)

    # Squared distances to the k nearest points of the seed cloud. The
    # first n_motif columns belong to the first motif copy generated, so
    # their maximum is the largest squared distance between x and motif.
    sqdists = cdist(x, seed_cloud, metric='sqeuclidean')
    motif_diam = np.sqrt(_max_in_first_n_columns(sqdists, n_motif))
    sqdists.partition(k - 1)
    sqdists = sqdists[:, :k]
    sqdists.sort()

    last_col = k - 1

    # Keep adding layers of lattice until they are too far away to give
    # nearer neighbours. For a lattice point l, points in l + motif are
    # further from x than |l| - max|p-p'| (p in x, p' in motif), which
    # gives the bound used to discard distant lattice points.
    for int_layer in further_layers:

        max_sqd = _max_in_column(sqdists, last_col)
        bound = (np.sqrt(max_sqd) + motif_diam) ** 2

        lattice = _close_lattice_points(int_layer, cell, bound)
        if lattice.size == 0:  # whole layer is out of range
            break

        # Squared distances from x to the points of the new layer
        layer_cloud = _lattice_to_cloud(motif, lattice)
        new_sqdists = cdist(x, layer_cloud, metric='sqeuclidean')
        improves = new_sqdists < max_sqd
        if not np.any(improves):  # no new point beats the current k-th
            break

        # Keep only columns that are close to some query point, then
        # reduce to (at most) the k smallest per row, sorted.
        new_sqdists = new_sqdists[:, np.any(improves, axis=0)]
        if new_sqdists.shape[-1] > k:
            new_sqdists.partition(last_col)
            new_sqdists = new_sqdists[:, :k]
        new_sqdists.sort()

        # Fold the new distances into the running k smallest
        sqdists = _merge_dists(sqdists, new_sqdists)

    return np.sqrt(sqdists)
104
105
106
def _k_smallest_sorted(
        sqdists: np.ndarray, k: int
) -> Tuple[np.ndarray, np.ndarray]:
    """Return the (up to) ``k`` smallest values in each row of
    ``sqdists`` in ascending order, with their column indices.

    When ``sqdists`` has ``k`` or fewer columns every column is kept, so
    the result has ``min(k, sqdists.shape[-1])`` columns. This avoids
    calling ``np.argpartition`` with a ``kth`` that is out of bounds.
    """

    if sqdists.shape[-1] > k:
        # Cheap selection of the k smallest, then sort only those
        candidates = np.argpartition(sqdists, k - 1)[:, :k]
        cand_sqdists = np.take_along_axis(sqdists, candidates, axis=-1)
        order = np.argsort(cand_sqdists)
        return (
            np.take_along_axis(cand_sqdists, order, axis=-1),
            np.take_along_axis(candidates, order, axis=-1)
        )

    # k or fewer columns: nothing to select, just sort everything
    order = np.argsort(sqdists)
    return np.take_along_axis(sqdists, order, axis=-1), order


def nearest_neighbours_data(
        motif: np.ndarray, cell: np.ndarray, x: np.ndarray, k: int
) -> Tuple[np.ndarray, ...]:
    """Find the ``k`` nearest neighbours in a periodic set for each
    point in ``x``, along with data about those neighbours.

    Given a periodic set described by ``motif`` and ``cell``, a query
    set of points ``x`` and an integer ``k``, find the ``k`` nearest
    neighbours in the periodic set for all points in ``x``. Return
    an array of distances to neighbours, the point cloud generated
    during the search and the indices of which points in the cloud are
    the neighbours of points in ``x``.

    Note: the point ``cloud[i]`` in the periodic set comes from the
    motif point ``motif[i % len(motif)]``, because points are added to
    ``cloud`` in batches of whole unit cells and not rearranged.

    Parameters
    ----------
    motif : :class:`numpy.ndarray`
        Cartesian coordinates of the motif, shape (no points, dims).
    cell : :class:`numpy.ndarray`
        The unit cell as a square array, shape (dims, dims).
    x : :class:`numpy.ndarray`
        Array of points to query for neighbours. For AMD/PDD invariants
        this is the motif, or more commonly an asymmetric unit of it.
    k : int
        Number of nearest neighbours to find for each point in ``x``.

    Returns
    -------
    dists : numpy.ndarray
        Array shape ``(x.shape[0], k)`` of distances from points in
        ``x`` to their ``k`` nearest neighbours in the periodic set in
        order, e.g. ``dists[m][n]`` is the distance from ``x[m]`` to its
        n-th nearest neighbour in the periodic set.
    cloud : numpy.ndarray
        Collection of points in the periodic set that were generated
        during the search. Arranged such that cloud[i] comes from the
        motif point motif[i % len(motif)] by translation.
    inds : numpy.ndarray
        Array shape ``(x.shape[0], k)`` containing the indices of
        nearest neighbours in ``cloud``, e.g. the n-th nearest neighbour
        to ``x[m]`` is ``cloud[inds[m][n]]``.
    """

    full_cloud = []
    m, dims = motif.shape
    # Get an initial collection of lattice points + a generator for more
    int_lattice, int_lat_generator = _integer_lattice_batches(dims, k / m)
    cloud = _lattice_to_cloud(motif, int_lattice @ cell)
    full_cloud.append(cloud)
    cloud_ind_offset = len(cloud)

    # Squared distances to k nearest neighbours + inds of neighbours in cloud
    sqdists = cdist(x, cloud, metric='sqeuclidean')
    motif_diam = np.sqrt(_max_in_first_n_columns(sqdists, m))
    sqdists, inds = _k_smallest_sorted(sqdists, k)

    # Generate layers of lattice until they are too far away to give nearer
    # neighbours. For a lattice point l, points in l + motif are further away
    # from x than |l| - max|p-p'| (p in x, p' in motif), giving a bound we can
    # use to rule out distant lattice points.
    max_sqd = _max_in_column(sqdists, k - 1)
    bound = (np.sqrt(max_sqd) + motif_diam) ** 2

    while True:

        # Get next layer of lattice
        lattice = _close_lattice_points(next(int_lat_generator), cell, bound)
        if lattice.size == 0:  # None are close enough
            break

        cloud = _lattice_to_cloud(motif, lattice)
        full_cloud.append(cloud)
        # Squared distances to new points
        sqdists_ = cdist(x, cloud, metric='sqeuclidean')
        close = sqdists_ < max_sqd
        if not np.any(close):  # None are close enough
            break

        # Squared distances to up to k nearest new points + inds. A layer
        # can contribute fewer than k points once distant lattice points
        # have been filtered out, so the helper must cope with that.
        sqdists_, inds_ = _k_smallest_sorted(sqdists_, k)

        # Move inds_ so they point to full_cloud instead of cloud
        inds_ += cloud_ind_offset
        cloud_ind_offset += len(cloud)

        # Merge sqdists and sqdists_, and inds and inds_
        sqdists, inds = _merge_dists_inds(sqdists, sqdists_, inds, inds_)
        max_sqd = _max_in_column(sqdists, k - 1)
        bound = (np.sqrt(max_sqd) + motif_diam) ** 2

    return np.sqrt(sqdists), np.concatenate(full_cloud), inds
209
210
211
def _integer_lattice_batches_cache(integer_lattice_batches_func):
    """Decorator implementing a specialised cache for
    ``_integer_lattice_batches``.

    The wrapped function returns a list of lattice layers plus a
    generator for further layers. The wrapper stores the layers already
    concatenated into one array, keyed by ``(dims, number of layers)``,
    so repeated calls needing the same number of layers skip both the
    generation and the concatenation. The number of layers needed is
    derived from ``k_over_m`` using per-dimension running totals of
    points per layer. Each call returns the cached array and a fresh
    ``tee`` of the cached generator.
    """

    batches = {}         # (dims, n_layers) -> [concatenated layers, generator]
    running_totals = {}  # dims -> cumulative #points after layer 1, 2, ...

    @functools.wraps(integer_lattice_batches_func)
    def wrapper(dims, k_over_m):

        totals = running_totals.setdefault(dims, [])

        # Count the known layers whose cumulative size does not exceed
        # k_over_m; two more layers than that are requested.
        n_covered = 0
        for total in totals:
            if total > k_over_m:
                break
            n_covered += 1
        key = (dims, n_covered + 2)

        if key not in batches:
            layers, generator = integer_lattice_batches_func(dims, k_over_m)
            # Key on the number of layers actually produced
            key = (dims, len(layers))

            # More layers than seen so far: refresh the running totals
            if len(running_totals[dims]) < len(layers):
                running_totals[dims] = list(
                    accumulate(len(layer) for layer in layers)
                )
            batches[key] = [np.concatenate(layers), generator]

        # Split the stored generator so the caller's copy can be consumed
        # without exhausting the cached one.
        entry = batches[key]
        entry[1], fresh_generator = tee(entry[1])

        return entry[0], fresh_generator

    return wrapper
252
253
254
@_integer_lattice_batches_cache
def _integer_lattice_batches(
        dims: int, k_over_m: float
) -> Tuple[np.ndarray, Iterable[np.ndarray]]:
    """Return an initial batch of integer lattice points (sized
    according to k / m) and a generator for more distant points.

    This undecorated body returns the initial batch as a list of layers;
    the ``_integer_lattice_batches_cache`` decorator concatenates them
    into a single array and caches the result for speed.

    Parameters
    ----------
    dims : int
        The dimension of Euclidean space the lattice is in.
    k_over_m : float
        Number of nearest neighbours to find (parameter of
        nearest_neighbours, k) divided by the number of motif points in
        the periodic set (m). Used to determine how many layers of the
        integer lattice to return.

    Returns
    -------
    initial_integer_lattice : :class:`numpy.ndarray`
        A collection of integer lattice points. Consists of the first
        few layers generated by ``integer_lattice_generator`` (number of
        layers depends on m, k).
    integer_lattice_generator
        A generator for integer lattice points more distant than those
        in ``initial_integer_lattice``.
    """

    layer_source = iter(_generate_integer_lattice(dims))

    # The first layer is counted as a single lattice point.
    initial_layers = [next(layer_source)]
    points_so_far = 1

    # Take layers until there are more lattice points than k / m ...
    while points_so_far <= k_over_m:
        extra = next(layer_source)
        initial_layers.append(extra)
        points_so_far += extra.shape[0]

    # ... then always take one further layer.
    initial_layers.append(next(layer_source))

    return initial_layers, layer_source
292
293
294
def memoized_generator(generator_function: Callable) -> Callable:
    """Decorator that caches the items produced by a generator function.

    One underlying generator is created per distinct tuple of positional
    arguments. Each call returns an independent iterator obtained with
    ``itertools.tee``, so items already produced are replayed rather
    than recomputed.
    """
    shared = {}  # args -> iterator kept for splitting with tee

    @functools.wraps(generator_function)
    def wrapper(*args) -> Iterable:
        try:
            source = shared[args]
        except KeyError:
            source = generator_function(*args)
        # Keep one branch for future calls, hand the other to the caller
        shared[args], fresh = tee(source)
        return fresh

    return wrapper
304
305
306
@memoized_generator
def _generate_integer_lattice(dims: int) -> Iterable[np.ndarray]:
    """Generate batches of integer lattice points. Each yield gives all
    points (that have not already been yielded) inside a sphere centered
    at the origin with radius d; d starts at 0 and increments by 1 on
    each loop.

    Parameters
    ----------
    dims : int
        The dimension of Euclidean space the lattice is in.

    Yields
    -------
    :class:`numpy.ndarray`
        Yields arrays of integer lattice points in `dims`-dimensional
        Euclidean space.
    """

    # One dimension: the origin first, then the pair (-d, d) for each d.
    if dims == 1:
        yield np.zeros((1, 1), dtype=np.float64)
        radius = 0
        while True:
            radius += 1
            yield np.array([[-radius], [radius]], dtype=np.float64)

    # For each choice of the first dims-1 (non-negative) coordinates,
    # the smallest last coordinate that has not been emitted yet.
    next_last = {}
    radius = 0
    while True:
        positive_points = []

        # Sweep repeatedly, adding one point per column per sweep, until
        # a sweep adds nothing inside the current radius.
        added = True
        while added:
            added = False
            for head in product(range(radius + 1), repeat=dims - 1):
                last = next_last.setdefault(head, 0)
                if sum(c ** 2 for c in head) + last ** 2 <= radius ** 2:
                    positive_points.append(head + (last,))
                    next_last[head] = last + 1
                    added = True

        # Mirror the non-negative points into all other orthants
        yield _reflect_positive_integer_lattice(
            np.array(positive_points, dtype=np.float64)
        )
        radius += 1
349
350
351
@numba.njit(cache=True, fastmath=True)
def _reflect_positive_integer_lattice(
        positive_int_lattice: np.ndarray
) -> np.ndarray:
    """Reflect points with non-negative coordinates across every
    combination of axes, without duplicating points that are invariant
    under a reflection (i.e. that have a zero on a reflected axis).
    """

    dims = positive_int_lattice.shape[-1]
    rows = []
    rows.extend(positive_int_lattice)

    for n_reflections in range(1, dims + 1):

        # First combination of n_reflections axes: 0, 1, ..., n - 1
        axes = np.arange(n_reflections)
        rows.extend(_reflect_in_axes(positive_int_lattice, axes))

        # Largest value position p may hold is p + highest_offset
        highest_offset = dims - n_reflections

        # Step through the remaining combinations in lexicographic order
        while True:
            # Rightmost position that is not yet at its largest value
            pos = n_reflections - 1
            while pos >= 0 and axes[pos] == pos + highest_offset:
                pos -= 1
            if pos < 0:  # every position is maxed out: all done
                break

            # Advance it and reset everything to its right
            axes[pos] += 1
            for later in range(pos + 1, n_reflections):
                axes[later] = axes[later - 1] + 1
            rows.extend(_reflect_in_axes(positive_int_lattice, axes))

    # Pack the collected rows into a single 2D array
    int_lattice = np.empty(shape=(len(rows), dims), dtype=np.float64)
    for row_ind in range(len(rows)):
        int_lattice[row_ind] = rows[row_ind]

    return int_lattice
387
388
389
@numba.njit(cache=True, fastmath=True)
def _reflect_in_axes(
        positive_int_lattice: np.ndarray, axes: np.ndarray
) -> np.ndarray:
    """Reflect points in `positive_int_lattice` in the axes described by
    `axes`. Points with a zero coordinate on any of those axes are left
    out, since reflecting them would duplicate an existing point.
    """
    # Number of zero coordinates each point has on the reflected axes
    zeros_on_axes = (positive_int_lattice[:, axes] == 0).sum(axis=-1)
    # Boolean indexing copies, so the input array is not modified below
    reflected = positive_int_lattice[zeros_on_axes == 0]
    reflected[:, axes] *= -1
    return reflected
400
401
402
@numba.njit(cache=True, fastmath=True)
def _close_lattice_points(
        int_lattice: np.ndarray, cell: np.ndarray, bound: float
) -> np.ndarray:
    """Map integer lattice points into the lattice of ``cell`` and keep
    those whose squared distance from the origin is below ``bound``,
    i.e. those whose motif copy could still contain nearest neighbours.

    ``bound`` should be equal to (max_d + motif_diam) ** 2, where max_d
    is the maximum k-th nearest neighbour distance found so far and
    motif_diam is the largest distance between any point in the query
    set and motif.
    """

    lattice = int_lattice @ cell

    # Rows of ``lattice`` whose squared norm is strictly below the bound
    kept_rows = []
    for row in range(lattice.shape[0]):
        sq_norm = 0
        for coord in lattice[row]:
            sq_norm += coord ** 2
        if sq_norm < bound:
            kept_rows.append(row)

    # Copy the surviving rows into a fresh float64 array
    close = np.empty((len(kept_rows), cell.shape[0]), dtype=np.float64)
    for out_row, src_row in enumerate(kept_rows):
        close[out_row] = lattice[src_row]
    return close
427
428
429
@numba.njit(cache=True, fastmath=True)
def _lattice_to_cloud(motif: np.ndarray, lattice: np.ndarray) -> np.ndarray:
    """Build a point cloud by translating ``motif`` by every point in
    ``lattice``. Copies are stored one after another in lattice order,
    so rows ``n * len(motif)`` to ``(n + 1) * len(motif)`` hold
    ``motif + lattice[n]``.
    """
    n_motif = motif.shape[0]
    n_lattice = lattice.shape[0]
    cloud = np.empty((n_motif * n_lattice, motif.shape[-1]), dtype=np.float64)
    for n in range(n_lattice):
        start = n * n_motif
        cloud[start:start + n_motif] = motif + lattice[n]
    return cloud
442
443
444
@numba.njit(cache=True, fastmath=True)
def _int_lattice_to_cloud(
        motif: np.ndarray, cell: np.ndarray, int_lattice: np.ndarray
) -> np.ndarray:
    """Build a point cloud for the periodic set given by ``motif`` and
    ``cell`` from a collection of integer lattice points
    ``int_lattice``.
    """
    # Integer coordinates -> lattice points of the unit cell
    lattice = int_lattice @ cell
    return _lattice_to_cloud(motif, lattice)
453
454
455
@numba.njit(cache=True, fastmath=True)
def _max_in_column(arr: np.ndarray, col: int) -> float:
    """Return the largest value in column ``col`` of ``arr``. The
    running maximum starts at 0, so all values of ``arr`` are assumed to
    be non-negative."""
    largest = 0
    for value in arr[:, col]:
        if value > largest:
            largest = value
    return largest
465
466
467
@numba.njit(cache=True, fastmath=True)
def _max_in_first_n_columns(arr: np.ndarray, n: int) -> float:
    """Return the largest value found in the first ``n`` columns of
    ``arr``. The running maximum starts at 0, so values are assumed to
    be non-negative."""
    largest = 0
    for col in range(n):
        for row in range(arr.shape[0]):
            value = arr[row, col]
            if value > largest:
                largest = value
    return largest
477
478
479
@numba.njit(cache=True, fastmath=True)
def _merge_dists(dists: np.ndarray, dists_: np.ndarray) -> np.ndarray:
    """Merge two 2D arrays sorted along last axis into one sorted array
    with same number of columns as ``dists``. Optimised for the distance
    arrays in nearest_neighbours, where ``dists`` will contain most of
    the smallest elements and only a few values in later columns will
    need to be replaced with values in ``dists_``.

    Parameters
    ----------
    dists : :class:`numpy.ndarray`
        Array with rows sorted in ascending order.
    dists_ : :class:`numpy.ndarray`
        Array with the same number of rows as ``dists``, rows sorted in
        ascending order; may have any number of columns.

    Returns
    -------
    :class:`numpy.ndarray`
        New array with the shape of ``dists`` whose row ``i`` holds the
        smallest values of ``dists[i]`` and ``dists_[i]`` combined, in
        ascending order. The inputs are not modified.
    """

    m, n_new_points = dists_.shape
    n_cols = dists.shape[1]
    ret = np.copy(dists)

    # Nothing to merge in (or nothing to merge into)
    if n_new_points == 0 or n_cols == 0:
        return ret

    for i in range(m):

        dp_ = 0
        d_ = dists_[i, dp_]

        # Traverse row backwards until value smaller than dists_[i, 0].
        # j is a negative index from the end of the row. The scan stops
        # at the start of the row so that it never indexes before the
        # first column when every value in dists[i] exceeds dists_[i, 0].
        j = 0
        while j > -n_cols and not (dists[i, j - 1] <= d_):
            j -= 1

        # If dists_[i, 0] >= dists[i, -1], no need to insert
        if j == 0:
            continue

        # dp points to dists[i], dp_ points to dists_[i]
        # Fill ret with the smaller dist, then increment pointers and repeat.
        # The first value written always comes from dists_, so dp stays
        # behind j and remains a valid negative index.
        dp = j
        d = dists[i, dp]

        while j < 0:

            if d <= d_:
                ret[i, j] = d
                dp += 1
                d = dists[i, dp]
            else:
                ret[i, j] = d_
                dp_ += 1

                if dp_ < n_new_points:
                    d_ = dists_[i, dp_]
                else:  # ran out of points in dists_
                    d_ = np.inf

            j += 1

    return ret
531
532
533
@numba.njit(cache=True, fastmath=True)
def _merge_dists_inds(
        dists: np.ndarray,
        dists_: np.ndarray,
        inds: np.ndarray,
        inds_: np.ndarray
) -> Tuple[np.ndarray, np.ndarray]:
    """Merge two row-sorted distance arrays into one sorted array with
    the same number of columns as ``dists``, and merge the index arrays
    ``inds`` and ``inds_`` in the same pattern ``dists`` and ``dists_``
    are merged.

    Parameters
    ----------
    dists : :class:`numpy.ndarray`
        Array with rows sorted in ascending order.
    dists_ : :class:`numpy.ndarray`
        Array with the same number of rows as ``dists``, rows sorted in
        ascending order; may have any number of columns.
    inds : :class:`numpy.ndarray`
        Array with the shape of ``dists``; ``inds[i, j]`` is carried
        along with ``dists[i, j]``.
    inds_ : :class:`numpy.ndarray`
        Array with the shape of ``dists_``; ``inds_[i, j]`` is carried
        along with ``dists_[i, j]``.

    Returns
    -------
    ret_dists : :class:`numpy.ndarray`
        New array with the shape of ``dists`` whose row ``i`` holds the
        smallest values of ``dists[i]`` and ``dists_[i]`` combined, in
        ascending order.
    ret_inds : :class:`numpy.ndarray`
        New array with the shape of ``inds`` holding the index that
        accompanies each value in ``ret_dists``.
    """

    m, n_new_points = dists_.shape
    n_cols = dists.shape[1]
    ret_dists = np.copy(dists)
    ret_inds = np.copy(inds)

    # Nothing to merge in (or nothing to merge into)
    if n_new_points == 0 or n_cols == 0:
        return ret_dists, ret_inds

    for i in range(m):

        dp_ = 0
        d_ = dists_[i, dp_]
        p_ = inds_[i, dp_]

        # Traverse row backwards until value smaller than dists_[i, 0].
        # j is a negative index from the end of the row. The scan stops
        # at the start of the row so that it never indexes before the
        # first column when every value in dists[i] exceeds dists_[i, 0].
        j = 0
        while j > -n_cols and not (dists[i, j - 1] <= d_):
            j -= 1

        # If dists_[i, 0] >= dists[i, -1], no need to insert
        if j == 0:
            continue

        # dp points to dists[i]/inds[i], dp_ points to dists_[i]/inds_[i].
        # The first value written always comes from dists_, so dp stays
        # behind j and remains a valid negative index.
        dp = j
        d = dists[i, dp]
        p = inds[i, dp]

        while j < 0:

            if d <= d_:
                ret_dists[i, j] = d
                ret_inds[i, j] = p
                dp += 1
                d = dists[i, dp]
                p = inds[i, dp]
            else:
                ret_dists[i, j] = d_
                ret_inds[i, j] = p_
                dp_ += 1

                if dp_ < n_new_points:
                    d_ = dists_[i, dp_]
                    p_ = inds_[i, dp_]
                else:  # ran out of points in dists_
                    d_ = np.inf

            j += 1

    return ret_dists, ret_inds
591
592
593
def nearest_neighbours_minval(
        motif: np.ndarray, cell: np.ndarray, min_val: float
) -> Tuple[np.ndarray, ...]:
    """Compute the same ``dists``/PDD matrix as ``nearest_neighbours``,
    but with enough columns such that all values in the last column are
    at least ``min_val``. Unlike ``nearest_neighbours``, does not take a
    query array ``x`` but only finds neighbours to motif points. Used in
    ``PDD_reconstructable``.

    Returns a tuple ``(dists, cloud, inds)``: the distance matrix with
    the first (zero-th neighbour) column dropped, the point cloud built
    during the search, and the neighbour indices returned by the last
    KD-tree query on that cloud.

    TODO: this function should be updated in line with
    nearest_neighbours.
    """

    def query_all_neighbours(points):
        # Distances/indices from every motif point to every cloud point
        tree = KDTree(
            points, leafsize=30, compact_nodes=False, balanced_tree=False
        )
        return tree.query(motif, k=points.shape[0])

    # Initial cloud: motif copies for the first three lattice layers
    layer_source = iter(_generate_integer_lattice(cell.shape[0]))
    cloud = np.concatenate([
        _lattice_to_cloud(motif, next(layer_source) @ cell)
        for _ in range(3)
    ])

    dists_, inds = query_all_neighbours(cloud)
    dists = np.zeros_like(dists_, dtype=np.float64)

    # Add layers & re-query until all distances smaller than min_val
    # stop changing between consecutive rounds
    max_cdist = np.amax(cdist(motif, motif))
    while True:

        if np.all(dists_[:, -1] >= min_val):
            # Columns up to and including the first one whose values
            # are all >= min_val
            first_full = np.argwhere(np.all(dists_ >= min_val, axis=0))[0][0]
            col = first_full + 1
            if np.array_equal(dists[:, :col], dists_[:, :col]):
                break

        dists = dists_
        lattice = next(layer_source) @ cell

        # Lower bound on the distance from motif points to each new copy
        closest_dist_bound = np.linalg.norm(lattice, axis=-1) - max_cdist
        is_close = closest_dist_bound <= np.amax(dists_[:, -1])
        if not np.any(is_close):
            break

        new_points = _lattice_to_cloud(motif, lattice[is_close])
        cloud = np.vstack((cloud, new_points))
        dists_, inds = query_all_neighbours(cloud)

    k = np.argwhere(np.all(dists >= min_val, axis=0))[0][0]
    return dists_[:, 1:k+1], cloud, inds
645
646
647
def generate_concentric_cloud(
        motif: np.ndarray, cell: np.ndarray
) -> Iterable[np.ndarray]:
    """Generates batches of points from a periodic set given by (motif,
    cell) which get successively further away from the origin.

    Each yield gives the motif translated by every lattice point in the
    next batch produced by
    ``_generate_integer_lattice(cell.shape[0])``, mapped through
    ``cell``.

    Parameters
    ----------
    motif : :class:`numpy.ndarray`
        Cartesian representation of the motif, shape (no points, dims).
    cell : :class:`numpy.ndarray`
        Cartesian representation of the unit cell, shape (dims, dims).

    Yields
    -------
    :class:`numpy.ndarray`
        Yields arrays of points from the periodic set.
    """

    dims = cell.shape[0]
    yield from (
        _lattice_to_cloud(motif, int_layer @ cell)
        for int_layer in _generate_integer_lattice(dims)
    )
673