Passed
Push — master ( 8cfb03...7c5bc8 )
by Daniel
07:13
created

generate_concentric_cloud()   A

Complexity

Conditions 3

Size

Total Lines 36
Code Lines 11

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 11
dl 0
loc 36
rs 9.85
c 0
b 0
f 0
cc 3
nop 2
1
"""Implements core function nearest_neighbours used for AMD and PDD calculations."""
2
3
import collections
4
from typing import Iterable, Optional
5
from itertools import product
6
7
import numba
8
import numpy as np
9
from scipy.spatial import KDTree
10
11
12
def nearest_neighbours(
        motif: np.ndarray,
        cell: np.ndarray,
        k: int,
        asymmetric_unit: Optional[np.ndarray] = None):
    """
    Find the k nearest neighbours of motif points in a periodic set.

    The periodic set is given by (motif, cell). Points of the set are
    produced layer by layer by generate_concentric_cloud and searched with
    a KDTree. Layers keep being added until two consecutive queries return
    the same distances (absolute tolerance 1e-12).

    Parameters
    ----------
    motif : numpy.ndarray
        Cartesian coords of the full motif, shape (no points, dims).
    cell : numpy.ndarray
        Cartesian coords of the unit cell, shape (dims, dims).
    k : int
        Number of nearest neighbours to find for each queried point.
    asymmetric_unit : numpy.ndarray, optional
        Indices pointing to an asymmetric unit in motif. If given, only
        motif[asymmetric_unit] is queried; otherwise the whole motif is.

    Returns
    -------
    pdd : numpy.ndarray
        Array with one row per queried point and k columns, holding the
        distances to its k nearest neighbours in order. The first column
        of the raw query result is dropped, so points do not count as
        their own nearest neighbour.
    cloud : numpy.ndarray
        All points of the periodic set generated during the search.
    inds : numpy.ndarray
        Array with the same shape as pdd containing the indices into
        cloud of the neighbours whose distances are in pdd.
    """

    query_points = motif if asymmetric_unit is None else motif[asymmetric_unit]
    layers = generate_concentric_cloud(motif, cell)

    # Gather layers until the cloud holds more than k points, so that the
    # k + 1 neighbours requested below exist, then add one further layer.
    collected = []
    total_points = 0
    while total_points <= k:
        layer = next(layers)
        total_points += layer.shape[0]
        collected.append(layer)
    collected.append(next(layers))
    cloud = np.concatenate(collected)

    tree = KDTree(cloud, compact_nodes=False, balanced_tree=False)
    dists, inds = tree.query(query_points, k=k+1, workers=-1)
    previous = np.zeros_like(dists)

    # Keep growing the cloud until an extra layer no longer changes the
    # distances found.
    while not np.allclose(previous, dists, atol=1e-12, rtol=0):
        previous = dists
        cloud = np.vstack((cloud, next(layers)))
        tree = KDTree(cloud, compact_nodes=False, balanced_tree=False)
        dists, inds = tree.query(query_points, k=k+1, workers=-1)

    return dists[:, 1:], cloud, inds[:, 1:]
76
77
78
def nearest_neighbours_minval(motif, cell, min_val):
    """Return distances to neighbours up to the first column >= min_val.

    Points of the periodic set (motif, cell) are generated layer by layer
    with generate_concentric_cloud, and every motif point is queried
    against the whole cloud. The search stops once the previous query's
    last column is entirely >= min_val and its columns, up to and
    including the first column that is entirely >= min_val, are unchanged
    in the newest query.

    Parameters
    ----------
    motif : numpy.ndarray
        Cartesian coords of the motif, shape (no points, dims).
    cell : numpy.ndarray
        Cartesian coords of the unit cell, shape (dims, dims).
    min_val : float
        Threshold every entry of the final returned column must reach.

    Returns
    -------
    numpy.ndarray
        Columns 1..c of the converged distance array, where c is the index
        of the first column whose entries are all >= min_val (column 0 is
        dropped).
    """

    layers = generate_concentric_cloud(motif, cell)

    # Start from the first three layers of the cloud.
    cloud = np.concatenate([next(layers) for _ in range(3)])
    tree = KDTree(cloud, compact_nodes=False, balanced_tree=False)
    latest, _ = tree.query(motif, k=cloud.shape[0], workers=-1)
    previous = np.zeros_like(latest)

    while True:
        if np.all(previous[:, -1] >= min_val):
            # Index of the first column lying entirely at or above min_val.
            cutoff = np.argwhere(np.all(previous >= min_val, axis=0))[0][0]
            stable = np.array_equal(previous[:, :cutoff+1],
                                    latest[:, :cutoff+1])
            if stable:
                return previous[:, 1:cutoff+1]

        previous = latest
        cloud = np.vstack((cloud, next(layers)))
        tree = KDTree(cloud, compact_nodes=False, balanced_tree=False)
        latest, _ = tree.query(motif, k=cloud.shape[0], workers=-1)
107
108
109
def generate_concentric_cloud(
        motif: np.ndarray,
        cell: np.ndarray
) -> Iterable[np.ndarray]:
    """
    Generates batches of points from a periodic set given by (motif, cell)
    which get successively further away from the origin.

    Each yield gives all points (that have not already been yielded) which
    lie in a unit cell whose corner lattice point was generated by
    generate_integer_lattice(cell.shape[0]).

    Parameters
    ----------
    motif : ndarray
        Cartesian representation of the motif, shape (no points, dims).
    cell : ndarray
        Cartesian representation of the unit cell, shape (dims, dims).

    Yields
    -------
    ndarray
        Float array of shape (len(motif) * n, dims), where n is the number
        of lattice points in the current batch. Rows are grouped by lattice
        point: rows [m*i, m*(i+1)) hold the motif shifted by the i-th
        translation of the batch.
    """

    m = len(motif)
    dims = cell.shape[0]

    # Iterate with a for loop instead of calling next() by hand: a
    # StopIteration escaping from next() inside a generator body is turned
    # into a RuntimeError (PEP 479), whereas a for loop simply ends.
    for int_lattice in generate_integer_lattice(dims):
        translations = int_lattice @ cell
        n_cells = len(translations)

        # Broadcast every translation against the whole motif in one step;
        # reshaping (n_cells, m, dims) -> (n_cells * m, dims) keeps the rows
        # grouped by translation.
        layer = np.empty((m * n_cells, dims))
        shifted = translations[:, np.newaxis, :] + motif
        layer[:] = shifted.reshape(m * n_cells, dims)

        yield layer
145
146
147
def generate_integer_lattice(dims: int) -> Iterable[np.ndarray]:
    """Generates batches of integer lattice points.

    Each yield gives all points (that have not already been yielded)
    inside a sphere centered at the origin with radius d. d starts at 0
    and increments by 1 on each loop.

    Parameters
    ----------
    dims : int
        The dimension of Euclidean space the lattice is in.

    Yields
    -------
    ndarray
        Yields arrays of integer points in dims dimensional Euclidean space.
    """

    radius = 0

    # One dimension is handled directly: the origin, then the pair -d, +d.
    if dims == 1:
        yield np.array([[0]])
        while True:
            radius += 1
            yield np.array([[-radius], [radius]])

    # For each tuple of the first dims - 1 non-negative coordinates, the
    # smallest last coordinate that has not been emitted yet.
    next_last = collections.defaultdict(int)

    while True:
        radius_sq = radius**2
        new_points = []

        # Sweep all leading coordinates repeatedly, raising the last
        # coordinate one step at a time, until a sweep adds no point.
        added = True
        while added:
            added = False
            for lead in product(range(radius+1), repeat=dims-1):
                last = next_last[lead]
                if _dist(lead, last) <= radius_sq:
                    new_points.append((*lead, last))
                    next_last[lead] = last + 1
                    added = True

        # Mirror the non-negative points into all other orthants.
        reflected = _reflect_positive_lattice(np.array(new_points))
        yield np.concatenate(reflected)
        radius += 1
190
191
192
@numba.njit()
def _reflect_positive_lattice(positive_int_lattice):
    """Reflect a set of points in the +ve quadrant in all axes.

    Returns a list of arrays. The first is the input itself; each further
    array corresponds to one non-empty subset of axes and contains the
    input points whose coordinates on those axes are all non-zero, with
    those coordinates negated. Points with a zero on a reflected axis are
    left out because reflecting them would reproduce an existing point.
    """
    dims = positive_int_lattice.shape[-1]
    batches = [positive_int_lattice]

    for n_reflections in range(1, dims + 1):

        # `axes` walks through every combination of n_reflections axes in
        # lexicographic order, starting from (0, 1, ..., n_reflections - 1).
        axes = np.arange(n_reflections)

        while True:
            # Keep only the points with no zero on any axis to be flipped.
            nonzero = (positive_int_lattice[:, axes] == 0).sum(axis=-1) == 0
            mirrored = positive_int_lattice[nonzero]
            mirrored[:, axes] *= -1
            batches.append(mirrored)

            # Find the right-most position that can still be advanced.
            pos = n_reflections - 1
            while pos >= 0 and axes[pos] == pos + dims - n_reflections:
                pos -= 1
            if pos < 0:
                # Every position is at its maximum: combinations exhausted.
                break

            # Advance that position and reset everything to its right.
            axes[pos] += 1
            for j in range(pos + 1, n_reflections):
                axes[j] = axes[j-1] + 1

    return batches
222
223
224
@numba.njit()
def _dist(xy, z):
    """Return the squared Euclidean norm of the point (*xy, z).

    No square root is taken: the result is z squared plus the sum of the
    squares of the entries of xy.
    """
    total = z ** 2
    for coord in xy:
        total += coord ** 2
    return total
230
231
232
# # @numba.njit()
233
# def cartesian_product(n, repeat):
234
#     arrays = [np.arange(n)] * repeat
235
#     arr = np.empty(tuple([n] * repeat + [repeat]), dtype=np.int64)
236
#     for i, a in enumerate(np.ix_(*arrays)):
237
#         arr[..., i] = a
238
#     return arr.reshape(-1, repeat)
239