Test Failed
Branch master (d8cf15)
by Daniel
03:54
created

amd._nearest_neighbours.nearest_neighbours()   B

Complexity

Conditions 5

Size

Total Lines 80
Code Lines 29

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 29
dl 0
loc 80
rs 8.7173
c 0
b 0
f 0
cc 5
nop 4

How to fix   Long Method   

Long Method

Small methods make your code easier to understand, in particular if combined with a good name. Besides, if your method is small, finding a good name is usually much easier.

For example, if you find yourself adding comments to a method's body, this is usually a good sign to extract the commented part to a new method, and use the comment as a starting point when coming up with a good name for this new method.

Commonly applied refactorings include Extract Method.

1
"""Implements core function nearest_neighbours used for AMD and PDD
2
calculations.
3
"""
4
5
from typing import Tuple, Iterable
6
from itertools import product, tee
7
import functools
8
9
import numba
10
import numpy as np
11
from scipy.spatial import KDTree
12
from scipy.spatial.distance import cdist
13
14
# Public API of this module; everything else is a private helper.
__all__ = [
    'nearest_neighbours',
    'nearest_neighbours_data',
    'nearest_neighbours_minval',
    'generate_concentric_cloud',
]
20
21
22
def nearest_neighbours(
        motif: np.ndarray, cell: np.ndarray, x: np.ndarray, k: int
) -> np.ndarray:
    """Find distances to ``k`` nearest neighbours in a periodic set for
    each point in ``x``.

    Given a periodic set described by ``motif`` and ``cell``, a query
    set of points ``x`` and an integer ``k``, find distances to the
    ``k`` nearest neighbours in the periodic set for all points in
    ``x``. Returns an array with shape (x.shape[0], k) of distances to
    the neighbours. This function only returns distances, see the
    function nearest_neighbours_data() to also get the point cloud and
    indices of the points which are neighbours.

    Parameters
    ----------
    motif : :class:`numpy.ndarray`
        Cartesian coordinates of the motif, shape (no points, dims).
    cell : :class:`numpy.ndarray`
        The unit cell as a square array, shape (dims, dims).
    x : :class:`numpy.ndarray`
        Array of points to query for neighbours. For AMD/PDD invariants
        this is the motif, or more commonly an asymmetric unit of it.
    k : int
        Number of nearest neighbours to find for each point in ``x``.

    Returns
    -------
    dists : numpy.ndarray
        Array shape ``(x.shape[0], k)`` of distances from points in
        ``x`` to their ``k`` nearest neighbours in the periodic set in
        order, e.g. ``dists[m][n]`` is the distance from ``x[m]`` to its
        n-th nearest neighbour in the periodic set.
    """

    m, dims = motif.shape
    # Get an initial collection of lattice points + a generator for more
    int_lattice, int_lat_generator = _integer_lattice_batches(dims, m, k)
    cloud = _int_lattice_to_cloud(motif, cell, int_lattice)

    # Squared distances from x to the initial cloud. The first integer
    # lattice point generated is the origin, so the first m columns are
    # distances from x to the motif itself; their maximum is the largest
    # distance between a query point and a motif point.
    sqdists = _cdist_sqeuclidean(x, cloud)
    motif_diam = np.sqrt(_max_in_columns(sqdists, m))
    sqdists = _k_smallest_sorted(sqdists, k)
    max_sqd, bound = _neighbour_search_bound(sqdists, motif_diam)

    # Generate layers of lattice until they are too far away to give nearer
    # neighbours. For a lattice point l, points in l + motif are further away
    # from x than |l| - max|p-p'| (p in x, p' in motif), giving a bound we can
    # use to rule out distant lattice points.
    while True:
        sqdists_ = _close_layer_sqdists(
            motif, cell, x, next(int_lat_generator), bound, max_sqd
        )
        if sqdists_ is None:  # Nothing in this layer is close enough
            break
        sqdists_ = _k_smallest_sorted(sqdists_, k)
        sqdists = _merge_sorted_arrays(sqdists, sqdists_)
        max_sqd, bound = _neighbour_search_bound(sqdists, motif_diam)

    return np.sqrt(sqdists)


def _k_smallest_sorted(sqdists, k):
    """Return the (at most) ``k`` smallest values of each row of the 2D
    array ``sqdists``, sorted ascending. Partitions ``sqdists`` in
    place, so pass an array the caller no longer needs intact.
    """
    # partition(k - 1) is only valid (and only useful) with > k columns.
    if sqdists.shape[-1] > k:
        sqdists.partition(k - 1)
    sqdists = sqdists[:, :k]
    sqdists.sort()
    return sqdists


def _neighbour_search_bound(sqdists, motif_diam):
    """Return ``(max_sqd, bound)``: the largest k-th neighbour squared
    distance over all rows of ``sqdists``, and the squared lattice-point
    norm beyond which a motif copy cannot contain a nearer neighbour.
    """
    max_sqd = np.amax(sqdists[:, -1])
    return max_sqd, (np.sqrt(max_sqd) + motif_diam) ** 2


def _close_layer_sqdists(motif, cell, x, int_layer, bound, max_sqd):
    """Squared distances from ``x`` to the points of one integer lattice
    layer that could be nearer than the current neighbours, keeping only
    columns that beat ``max_sqd`` for at least one query point. Returns
    ``None`` when the layer contributes nothing, which ends the search.
    """
    lattice = _close_lattice_points(int_layer, cell, bound)
    if lattice.size == 0:  # No lattice point is within the bound
        return None
    sqdists_ = _cdist_sqeuclidean(x, _lattice_to_cloud(motif, lattice))
    close = sqdists_ < max_sqd
    if not np.any(close):  # No new point beats a current neighbour
        return None
    return sqdists_[:, np.any(close, axis=0)]
102
103
104
def nearest_neighbours_data(
        motif: np.ndarray, cell: np.ndarray, x: np.ndarray, k: int
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Find the ``k`` nearest neighbours in a periodic set for each
    point in ``x``, along with data about those neighbours.

    Given a periodic set described by ``motif`` and ``cell``, a query
    set of points ``x`` and an integer ``k``, find the ``k`` nearest
    neighbours in the periodic set for all points in ``x``. Return
    an array of distances to neighbours, the point cloud generated
    during the search and the indices of which points in the cloud are
    the neighbours of points in ``x``.

    Note: the point ``cloud[i]`` in the periodic set comes from the
    motif point ``motif[i % len(motif)]``, because points are added to
    ``cloud`` in batches of whole unit cells and not rearranged.

    Parameters
    ----------
    motif : :class:`numpy.ndarray`
        Cartesian coordinates of the motif, shape (no points, dims).
    cell : :class:`numpy.ndarray`
        The unit cell as a square array, shape (dims, dims).
    x : :class:`numpy.ndarray`
        Array of points to query for neighbours. For AMD/PDD invariants
        this is the motif, or more commonly an asymmetric unit of it.
    k : int
        Number of nearest neighbours to find for each point in ``x``.

    Returns
    -------
    dists : numpy.ndarray
        Array shape ``(x.shape[0], k)`` of distances from points in
        ``x`` to their ``k`` nearest neighbours in the periodic set in
        order, e.g. ``dists[m][n]`` is the distance from ``x[m]`` to its
        n-th nearest neighbour in the periodic set.
    cloud : numpy.ndarray
        Collection of points in the periodic set that were generated
        during the search. Arranged such that cloud[i] comes from the
        motif point motif[i % len(motif)] by translation.
    inds : numpy.ndarray
        Array shape ``(x.shape[0], k)`` containing the indices of
        nearest neighbours in ``cloud``, e.g. the n-th nearest neighbour
        to ``x[m]`` is ``cloud[inds[m][n]]``.
    """

    m, dims = motif.shape
    # Get an initial collection of lattice points + a generator for more
    int_lattice, int_lat_generator = _integer_lattice_batches(dims, m, k)
    cloud = _int_lattice_to_cloud(motif, cell, int_lattice)
    full_cloud = [cloud]
    # Index in the concatenated cloud where the next batch will start
    cloud_ind_offset = len(cloud)

    # Squared distances to k nearest neighbours + inds of neighbours in cloud.
    # The first m cloud points are the motif itself (origin lattice point).
    sqdists = _cdist_sqeuclidean(x, cloud)
    motif_diam = np.sqrt(_max_in_columns(sqdists, m))
    sqdists, inds = _k_smallest_sorted_with_inds(sqdists, k)

    # Generate layers of lattice until they are too far away to give nearer
    # neighbours. For a lattice point l, points in l + motif are further away
    # from x than |l| - max|p-p'| (p in x, p' in motif), giving a bound we can
    # use to rule out distant lattice points.
    max_sqd = np.amax(sqdists[:, -1])
    bound = (np.sqrt(max_sqd) + motif_diam) ** 2

    while True:

        # Get next layer of lattice
        lattice = _close_lattice_points(next(int_lat_generator), cell, bound)
        if lattice.size == 0:  # None are close enough
            break

        cloud = _lattice_to_cloud(motif, lattice)
        full_cloud.append(cloud)
        # Squared distances to new points
        sqdists_ = _cdist_sqeuclidean(x, cloud)
        if not np.any(sqdists_ < max_sqd):  # None are close enough
            break

        # Up to k nearest new points + their inds. The filtered layer may
        # hold fewer than k points, which the helper handles.
        sqdists_, inds_ = _k_smallest_sorted_with_inds(sqdists_, k)

        # Move inds_ so they point to full_cloud instead of cloud
        inds_ += cloud_ind_offset
        cloud_ind_offset += len(cloud)

        # Merge sqdists and sqdists_, and inds and inds_
        sqdists, inds = _merge_sorted_arrays_inds(sqdists, sqdists_, inds, inds_)
        max_sqd = np.amax(sqdists[:, -1])
        bound = (np.sqrt(max_sqd) + motif_diam) ** 2

    return np.sqrt(sqdists), np.concatenate(full_cloud), inds


def _k_smallest_sorted_with_inds(sqdists, k):
    """Return the (at most) ``k`` smallest values of each row of the 2D
    array ``sqdists`` in ascending order, and their column indices.

    Unlike a bare ``np.argpartition(sqdists, k - 1)``, this also works
    when ``sqdists`` has fewer than ``k`` columns (all are returned).
    """
    # kth must be a valid column index, so clamp it to the row length.
    kth = min(k, sqdists.shape[-1]) - 1
    part_inds = np.argpartition(sqdists, kth, axis=-1)[:, :k]
    part_sqdists = np.take_along_axis(sqdists, part_inds, axis=-1)
    order = np.argsort(part_sqdists, axis=-1)
    inds = np.take_along_axis(part_inds, order, axis=-1)
    return np.take_along_axis(part_sqdists, order, axis=-1), inds
206
207
208
def _integer_lattice_batches_cache(f):
209
    """Specialised cache for ``_integer_lattice_batches()``."""
210
211
    cache = {}
212
    num_points_cache = {}
213
214
    @functools.wraps(f)
215
    def wrapper(dims, m, k):
216
217
        if dims not in num_points_cache:
0 ignored issues
show
Comprehensibility Best Practice introduced by
The variable num_points_cache does not seem to be defined.
Loading history...
218
            num_points_cache[dims] = []
219
220
        n_points = 0
221
        n_layers = 0
222
        within_cache = False
223
        for num_p in num_points_cache[dims]:
224
            if n_points > k / m:
225
                within_cache = True
226
                break
227
            n_points += num_p
228
            n_layers += 1
229
        n_layers += 1
230
231
        if not (within_cache and (dims, n_layers) in cache):
0 ignored issues
show
Comprehensibility Best Practice introduced by
The variable cache does not seem to be defined.
Loading history...
232
            layers, int_lat_generator = f(dims, m, k)
233
            n_layers = len(layers)
234
            if len(num_points_cache[dims]) < n_layers:
235
                num_points_cache[dims] = [len(i) for i in layers]
236
            layers = np.concatenate(layers)
237
            cache[(dims, n_layers)] = [layers, int_lat_generator]
238
239
        arr, g = cache[(dims, n_layers)]
240
        cache[(dims, n_layers)][1], r = tee(g)
241
        return arr, r
242
243
    return wrapper
244
245
246
@_integer_lattice_batches_cache
def _integer_lattice_batches(dims, m, k):
    """Return an initial batch of integer lattice layers (how many
    depends on ``m`` and ``k``) and a generator for the more distant
    layers.

    Layers are taken from ``_generate_integer_lattice(dims)``, starting
    with the origin, until they hold more than ``k / m`` points in total.
    One further layer is then added.

    Parameters
    ----------
    dims : int
        The dimension of Euclidean space the lattice is in.
    m : int
        Number of motif points.
    k : int
        Number of nearest neighbours to find (parameter of
        nearest_neighbours).

    Returns
    -------
    layers : list of :class:`numpy.ndarray`
        The initial integer lattice layers. The
        ``_integer_lattice_batches_cache`` decorator concatenates these
        into a single array before it is returned to callers.
    integer_lattice_generator
        Iterator continuing with the layer after the last one in
        ``layers``.
    """

    lattice_layers = iter(_generate_integer_lattice(dims))
    layers = [next(lattice_layers)]  # The origin alone
    total = 1
    while total <= k / m:
        next_layer = next(lattice_layers)
        layers.append(next_layer)
        total += next_layer.shape[0]
    # One extra layer beyond the point-count threshold
    layers.append(next(lattice_layers))
    return layers, lattice_layers
281
282
283
def memoized_generator(f):
    """Decorator caching a generator function's output per argument
    tuple.

    The first call with given positional ``args`` creates the generator.
    Every call, including the first, returns a fresh ``itertools.tee``
    copy of it. Values are therefore computed once and replayed from the
    start for each new caller. The cache is unbounded and holds every
    value produced so far.
    """
    saved = {}

    @functools.wraps(f)
    def wrapper(*args):
        try:
            source = saved[args]
        except KeyError:
            source = f(*args)
        # Store one copy back so later callers start from the beginning.
        saved[args], fresh = tee(source)
        return fresh

    return wrapper
293
294
295
@memoized_generator
def _generate_integer_lattice(dims: int) -> Iterable[np.ndarray]:
    """Generate batches of integer lattice points. Each yield gives all
    points (that have not already been yielded) inside the closed ball
    centered at the origin with radius d (norm <= d); d starts at 0 and
    increments by 1 on each loop.

    Because of the ``memoized_generator`` decorator, repeated calls with
    the same ``dims`` return independent iterators over one shared,
    cached generator, so each layer is computed only once per ``dims``.

    Parameters
    ----------
    dims : int
        The dimension of Euclidean space the lattice is in.

    Yields
    ------
    :class:`numpy.ndarray`
        Arrays (dtype float64) of integer lattice points in
        ``dims``-dimensional Euclidean space, shape (no points, dims).
        The first yield is the origin alone.
    """

    d = 0
    # 1D special case: layer d is just the two points -d and d.
    if dims == 1:
        yield np.zeros((1, 1), dtype=np.float64)
        while True:
            d += 1
            yield np.array([[-d], [d]], dtype=np.float64)

    # Build the non-negative-orthant points of each layer, then reflect
    # them. ymax maps the first dims-1 coordinates ``xy`` to the smallest
    # last-coordinate value not yet emitted for that ``xy``. It persists
    # across layers, so no point is yielded twice.
    ymax = {}
    while True:
        positive_int_lattice = []
        # Each pass over xy adds at most one point per xy (its next last
        # coordinate); repeat until no xy can grow within radius d.
        while True:
            batch = False
            for xy in product(range(d + 1), repeat=dims-1):
                if xy not in ymax:
                    ymax[xy] = 0
                if sum(i**2 for i in xy) + ymax[xy]**2 <= d**2:
                    positive_int_lattice.append((*xy, ymax[xy]))
                    batch = True
                    ymax[xy] += 1
            if not batch:
                break
        pos_int_lat = np.array(positive_int_lattice, dtype=np.float64)
        # Expand to all sign combinations without duplicating points.
        yield _reflect_positive_integer_lattice(pos_int_lat)
        d += 1
339
340
@numba.njit(cache=True, fastmath=True)
def _reflect_positive_integer_lattice(
        positive_int_lattice: np.ndarray
) -> np.ndarray:
    """Reflect points in the non-negative orthant across all
    combinations of axes, without duplicating points that are invariant
    under reflections.

    The output starts with the input points. They are followed by the
    points reflected in each combination of axes: all 1-axis
    combinations, then all 2-axis combinations, and so on. Within each
    size, combinations are visited in lexicographic order of axis indices.
    A point with a zero coordinate on a reflected axis is left out of
    that reflection (see ``_reflect_in_axes``), since it would
    reproduce a point that is already present.
    """

    dims = positive_int_lattice.shape[-1]
    batches = []
    # Start with the unreflected points (one row per list entry).
    batches.extend(positive_int_lattice)

    for n_reflections in range(1, dims + 1):

        # First combination of n_reflections axes: (0, 1, ..., n-1).
        axes = np.arange(n_reflections)
        batches.extend(_reflect_in_axes(positive_int_lattice, axes))

        # Step ``axes`` to the next combination in lexicographic order.
        while True:
            # Find the rightmost position not yet at its maximum value.
            # If every position is at its maximum (for-else), all
            # combinations of this size have been visited.
            i = n_reflections - 1
            for _ in range(n_reflections):
                if axes[i] != i + dims - n_reflections:
                    break
                i -= 1
            else:
                break
            # Increment it and reset the positions to its right to the
            # smallest increasing continuation.
            axes[i] += 1
            for j in range(i + 1, n_reflections):
                axes[j] = axes[j-1] + 1
            batches.extend(_reflect_in_axes(positive_int_lattice, axes))

    # Copy the collected rows into one contiguous array.
    int_lattice = np.empty(shape=(len(batches), dims), dtype=np.float64)
    for i in range(len(batches)):
        int_lattice[i] = batches[i]

    return int_lattice
376
377
378
@numba.njit(cache=True, fastmath=True)
def _reflect_in_axes(
        positive_int_lattice: np.ndarray, axes: np.ndarray
) -> np.ndarray:
    """Reflect points in `positive_int_lattice` in the axes described by
    `axes`, without duplicating invariant points.

    Only rows whose coordinates are non-zero on every axis in ``axes``
    are kept, and those coordinates are negated. The result comes from
    boolean indexing, i.e. a copy, so the input array is left unchanged
    and can be reused for the next combination of axes.
    """
    # True for rows with no zero coordinate among the chosen axes.
    not_on_axes = (positive_int_lattice[:, axes] == 0).sum(axis=-1) == 0
    int_lattice = positive_int_lattice[not_on_axes]
    int_lattice[:, axes] *= -1
    return int_lattice
389
390
391
@numba.njit(cache=True, fastmath=True)
def _close_lattice_points(
        int_lattice: np.ndarray, cell: np.ndarray, bound: float
) -> np.ndarray:
    """Keep only the lattice points near enough to matter.

    The integer points ``int_lattice`` are mapped to Cartesian lattice
    points via ``cell``. A point ``l`` is kept when ``|l|**2 < bound``.
    The nearest-neighbour searches pass
    ``bound = (max_d + motif_diam) ** 2``, where ``max_d`` is the largest
    k-th nearest neighbour distance found so far and ``motif_diam`` is
    the largest distance between a query point and a motif point. A
    discarded lattice point's motif copy therefore cannot hold a nearer
    neighbour.

    Returns
    -------
    numpy.ndarray
        The kept Cartesian lattice points in their original order,
        shape (no kept points, dims).
    """

    lattice = int_lattice @ cell
    n_pts = lattice.shape[0]
    n_dims = lattice.shape[-1]

    # First pass: flag the points inside the bound and count them.
    keep = np.zeros(n_pts, dtype=np.bool_)
    n_keep = 0
    for p in range(n_pts):
        sq_norm = 0.0
        for coord in lattice[p]:
            sq_norm += coord ** 2
        if sq_norm < bound:
            keep[p] = True
            n_keep += 1

    # Second pass: copy the flagged points into a right-sized array.
    close = np.empty((n_keep, n_dims), dtype=np.float64)
    row = 0
    for p in range(n_pts):
        if keep[p]:
            close[row] = lattice[p]
            row += 1
    return close
415
416
417
@numba.njit(cache=True, fastmath=True)
def _lattice_to_cloud(motif: np.ndarray, lattice: np.ndarray) -> np.ndarray:
    """Build points of a periodic set from a batch of Cartesian lattice
    points (integer lattice points already multiplied by the cell): one
    translated copy of ``motif`` per lattice point.

    Rows ``t * len(motif)`` up to ``(t + 1) * len(motif)`` (exclusive)
    of the result are ``motif + lattice[t]``. Result row ``i`` therefore
    comes from ``motif[i % len(motif)]``.
    """

    n_motif = motif.shape[0]
    n_translations = lattice.shape[0]
    cloud = np.empty(
        (n_motif * n_translations, motif.shape[-1]), dtype=np.float64
    )
    for t in range(n_translations):
        start = t * n_motif
        cloud[start:start + n_motif] = motif + lattice[t]
    return cloud
432
433
434
@numba.njit(cache=True, fastmath=True)
def _int_lattice_to_cloud(
        motif: np.ndarray, cell: np.ndarray, int_lattice: np.ndarray
) -> np.ndarray:
    """Build points of a periodic set from a batch of integer lattice
    points (as yielded by ``_generate_integer_lattice``).

    The integer points are first converted to Cartesian lattice points
    with ``cell``. One copy of ``motif`` is then placed at each.
    """
    cartesian_lattice = int_lattice @ cell
    return _lattice_to_cloud(motif, cartesian_lattice)
443
444
445
@numba.njit(cache=True, fastmath=True)
def _cdist_sqeuclidean(arr1, arr2):
    """Pairwise squared Euclidean distances.

    Returns a float64 array of shape ``(len(arr1), len(arr2))``. Entry
    ``[i, j]`` is ``sum((arr1[i] - arr2[j]) ** 2)``, with the number of
    coordinates taken from ``arr1``.
    """
    n_rows = arr1.shape[0]
    n_cols = arr2.shape[0]
    n_dims = arr1.shape[-1]
    out = np.empty((n_rows, n_cols), dtype=np.float64)
    for r in range(n_rows):
        for c in range(n_cols):
            acc = 0.
            for d in range(n_dims):
                diff = arr1[r, d] - arr2[c, d]
                acc += diff ** 2
            out[r, c] = acc
    return out
458
459
460
@numba.njit(cache=True, fastmath=True)
def _max_in_column(arr, col):
    """Largest value in column ``col`` of the 2D array ``arr``.

    The running maximum starts at 0, so the result is never below 0 and
    an empty column gives 0. This assumes the values of ``arr`` are
    non-negative.
    """
    best = 0
    for row in range(arr.shape[0]):
        candidate = arr[row, col]
        if candidate > best:
            best = candidate
    return best
470
471
472
@numba.njit(cache=True, fastmath=True)
def _max_in_columns(arr, max_col):
    """Largest value in the first ``max_col`` columns of the 2D array
    ``arr`` (columns ``0`` to ``max_col - 1``).

    The running maximum starts at 0, so the result is never below 0 and
    ``max_col == 0`` gives 0. This assumes the values of ``arr`` are
    non-negative.
    """
    best = 0
    for c in range(max_col):
        for r in range(arr.shape[0]):
            value = arr[r, c]
            if value > best:
                best = value
    return best
482
483
484
@numba.njit(cache=True, fastmath=True)
def _merge_sorted_arrays(dists, dists_):
    """Merge two 2D arrays sorted along last axis into one sorted array
    with same number of columns as ``dists``. Optimised for the distance
    arrays in nearest_neighbours, where ``dists`` will contain most of
    the smallest elements and only a few values in later columns will
    need to be replaced with values in ``dists_``.

    Parameters
    ----------
    dists : numpy.ndarray
        Array shape (m, k), each row sorted in ascending order.
    dists_ : numpy.ndarray
        Array shape (m, n), each row sorted in ascending order; n may
        differ from k.

    Returns
    -------
    numpy.ndarray
        Array shape (m, k) whose i-th row holds the k smallest values of
        ``dists[i]`` and ``dists_[i]`` combined, in ascending order. On
        ties the value from ``dists`` comes first.
    """

    m, n_new_points = dists_.shape
    k = dists.shape[-1]
    ret = np.copy(dists)

    if n_new_points == 0:  # Nothing to merge in
        return ret

    for i in range(m):
        # Walk row i backwards past every value larger than the smallest
        # new value dists_[i, 0]. Stop at -k: if the new value is smaller
        # than the whole row, the scan must not run off the start of the
        # row (Numba does not bounds-check indexing by default).
        d_min_new = dists_[i, 0]
        j = 0
        while j > -k and dists[i, j - 1] > d_min_new:
            j -= 1

        if j == 0:  # If dists_[i, 0] >= dists[i, -1], no need to insert
            continue

        # Fill ret[i, j:] by repeatedly taking the smaller of the next
        # unused values dists[i, dp] and dists_[i, dp_]. dp == j - dp_ < 0
        # inside the loop, so dists[i, dp] stays within the row. The
        # exhausted-dists_ case is tested explicitly rather than with an
        # np.inf sentinel, because fastmath=True lets the compiler assume
        # values are never infinite.
        dp = j
        dp_ = 0
        while j < 0:
            if dp_ < n_new_points and dists_[i, dp_] < dists[i, dp]:
                ret[i, j] = dists_[i, dp_]
                dp_ += 1
            else:
                ret[i, j] = dists[i, dp]
                dp += 1
            j += 1

    return ret
530
531
532
@numba.njit(cache=True, fastmath=True)
def _merge_sorted_arrays_inds(dists, dists_, inds, inds_):
    """The same as _merge_sorted_arrays, but also merges two arrays
    ``inds`` and ``inds_`` in the same pattern ``dists`` and ``dists_``
    are merged.

    Parameters
    ----------
    dists : numpy.ndarray
        Array shape (m, k), each row sorted in ascending order.
    dists_ : numpy.ndarray
        Array shape (m, n), each row sorted in ascending order; n may
        differ from k.
    inds : numpy.ndarray
        Array the same shape as ``dists``, carried along with it.
    inds_ : numpy.ndarray
        Array the same shape as ``dists_``, carried along with it.

    Returns
    -------
    ret_dists, ret_inds : numpy.ndarray
        Arrays shape (m, k). Row i of ``ret_dists`` holds the k smallest
        values of ``dists[i]`` and ``dists_[i]`` combined, in ascending
        order, with ties taking the value from ``dists`` first.
        ``ret_inds`` holds the matching entries of ``inds``/``inds_``.
    """

    m, n_new_points = dists_.shape
    k = dists.shape[-1]
    ret_dists = np.copy(dists)
    ret_inds = np.copy(inds)

    if n_new_points == 0:  # Nothing to merge in
        return ret_dists, ret_inds

    for i in range(m):
        # Walk row i backwards past every value larger than the smallest
        # new value dists_[i, 0]. Stop at -k so a new value smaller than
        # the whole row cannot send the index out of the row (Numba does
        # not bounds-check indexing by default).
        d_min_new = dists_[i, 0]
        j = 0
        while j > -k and dists[i, j - 1] > d_min_new:
            j -= 1

        if j == 0:  # If dists_[i, 0] >= dists[i, -1], no need to insert
            continue

        # Fill ret_dists[i, j:] (and ret_inds) by repeatedly taking the
        # smaller of the next unused values dists[i, dp] and dists_[i, dp_].
        # dp == j - dp_ < 0 inside the loop, so dists[i, dp] stays in the
        # row. No np.inf sentinel: fastmath=True lets the compiler assume
        # values are never infinite.
        dp = j
        dp_ = 0
        while j < 0:
            if dp_ < n_new_points and dists_[i, dp_] < dists[i, dp]:
                ret_dists[i, j] = dists_[i, dp_]
                ret_inds[i, j] = inds_[i, dp_]
                dp_ += 1
            else:
                ret_dists[i, j] = dists[i, dp]
                ret_inds[i, j] = inds[i, dp]
                dp += 1
            j += 1

    return ret_dists, ret_inds
583
584
585
def nearest_neighbours_minval(
        motif: np.ndarray, cell: np.ndarray, min_val: float
) -> Tuple[np.ndarray, ...]:
    """Return the same ``dists``/PDD matrix as ``nearest_neighbours``,
    but with enough columns such that all values in the last column are
    at least ``min_val``. Unlike ``nearest_neighbours``, does not take a
    query array ``x`` but only finds neighbours to motif points. Used in
    ``PDD_reconstructable``.

    TODO: this function should be updated in line with
    nearest_neighbours.

    Parameters
    ----------
    motif : :class:`numpy.ndarray`
        Cartesian coordinates of the motif, shape (no points, dims).
    cell : :class:`numpy.ndarray`
        The unit cell as a square array, shape (dims, dims).
    min_val : float
        Every value in the last returned column is at least this.

    Returns
    -------
    dists : numpy.ndarray
        Distances from each motif point to its nearest neighbours in
        order, without the first column (normally each point's zero
        distance to itself). Ends with the first column in which every
        value is at least ``min_val``.
    cloud : numpy.ndarray
        The point cloud that was searched.
    inds : numpy.ndarray
        Indices into ``cloud`` of the neighbours of each motif point for
        every point of ``cloud``, including the first column. It is not
        trimmed to the shape of ``dists``.
    """

    def _query_all_neighbours(points):
        # Rank every point of the cloud for every motif point.
        tree = KDTree(
            points, leafsize=30, compact_nodes=False, balanced_tree=False
        )
        return tree.query(motif, k=points.shape[0])

    # Generate initial cloud of points from the periodic set
    int_lat_generator = iter(_generate_integer_lattice(cell.shape[0]))
    cloud = []
    for _ in range(3):
        cloud.append(_lattice_to_cloud(motif, next(int_lat_generator) @ cell))
    cloud = np.concatenate(cloud)

    # Find all neighbours in the point cloud for points in motif. ``dists``
    # holds the previous iteration's result (zeros before the first) so the
    # loop can tell when the leading columns have stopped changing.
    dists_, inds = _query_all_neighbours(cloud)
    dists = np.zeros_like(dists_, dtype=np.float64)

    # Add layers & find neighbours until all distances smaller than
    # min_val don't change
    max_cdist = np.amax(cdist(motif, motif))
    while True:
        if np.all(dists_[:, -1] >= min_val):
            col = np.argwhere(np.all(dists_ >= min_val, axis=0))[0][0] + 1
            if np.array_equal(dists[:, :col], dists_[:, :col]):
                break
        dists = dists_
        lattice = next(int_lat_generator) @ cell
        # Points of l + motif are at least |l| - max|p - p'| from any
        # motif point (triangle inequality).
        closest_dist_bound = np.linalg.norm(lattice, axis=-1) - max_cdist
        is_close = closest_dist_bound <= np.amax(dists_[:, -1])
        if not np.any(is_close):
            break
        cloud = np.vstack((cloud, _lattice_to_cloud(motif, lattice[is_close])))
        dists_, inds = _query_all_neighbours(cloud)

    # Index of the first column in which every distance reaches min_val.
    # It is taken from dists_, the array that is sliced. On both exit paths
    # its leading columns equal those of dists, so the index is the same.
    k = np.argwhere(np.all(dists_ >= min_val, axis=0))[0][0]
    return dists_[:, 1:k+1], cloud, inds
634
635
636
def generate_concentric_cloud(motif, cell):
    """Generate batches of points from the periodic set given by
    (``motif``, ``cell``) that get successively further from the origin.

    Each yield contains ``motif + l`` for every lattice point ``l`` in
    the next layer of ``_generate_integer_lattice(cell.shape[0])``
    (integer points multiplied by ``cell``). Layers never repeat points,
    so no point of the periodic set is yielded twice.

    Parameters
    ----------
    motif : :class:`numpy.ndarray`
        Cartesian representation of the motif, shape (no points, dims).
    cell : :class:`numpy.ndarray`
        Cartesian representation of the unit cell, shape (dims, dims).

    Yields
    ------
    :class:`numpy.ndarray`
        Arrays of points from the periodic set.
    """

    dims = cell.shape[0]
    for int_layer in _generate_integer_lattice(dims):
        translations = int_layer @ cell
        yield _lattice_to_cloud(motif, translations)
660