Test Failed
Push — master ( 9779af...19222a )
by Daniel
02:55
created

amd._nns.cartesian_product()   A

Complexity

Conditions 2

Size

Total Lines 7
Code Lines 6

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 6
dl 0
loc 7
rs 10
c 0
b 0
f 0
cc 2
nop 1
1
"""Implements core function nearest_neighbours used for AMD and PDD calculations."""
2
3
import collections
4
from typing import Iterable
5
from itertools import product
6
7
import numba
8
import numpy as np
9
from scipy.spatial import KDTree
10
11
12
def nearest_neighbours(
        motif: np.ndarray,
        cell: np.ndarray,
        x: np.ndarray,
        k: int):
    """
    Find the k nearest neighbours in a periodic set to each point of x.

    The periodic set (motif, cell) is expanded layer by layer with
    generate_concentric_cloud, and the finite cloud is queried with a
    KDTree. Layers keep being added until one more layer no longer changes
    the distances (absolute tolerance 1e-10).

    Parameters
    ----------
    motif : numpy.ndarray
        Orthogonal (Cartesian) coords of the motif, shape (no points, dims).
    cell : numpy.ndarray
        Orthogonal (Cartesian) coords of the unit cell, shape (dims, dims).
    x : numpy.ndarray
        Array of points to query for neighbours. For invariants of crystals
        this is the asymmetric unit.
    k : int
        Number of nearest neighbours to find for each point in x.

    Returns
    -------
    pdd : numpy.ndarray
        Array with one row per point in x and k columns, holding the
        distances from each query point to its neighbours in ascending
        order. E.g. pdd[m][n] is the distance from x[m] to its n-th
        nearest neighbour in the periodic set.
    cloud : numpy.ndarray
        Collection of points in the periodic set that was generated
        during the nearest neighbour search.
    inds : numpy.ndarray
        Array with the same shape as pdd containing the indices of the
        nearest neighbours in cloud. E.g. the n-th nearest neighbour to
        x[m] is cloud[inds[m][n]].

    Notes
    -----
    k + 1 neighbours are queried and the closest one (column 0) is dropped
    from both pdd and inds. This presumably relies on every point of x
    being itself a member of the periodic set, so that column 0 is the
    zero distance of a point to itself.
    """

    layers = generate_concentric_cloud(motif, cell)

    # Gather layers until the cloud holds strictly more than k points (so a
    # query for k + 1 neighbours can be satisfied), then add one extra layer.
    gathered = []
    total_points = 0
    while total_points <= k:
        layer = next(layers)
        total_points += layer.shape[0]
        gathered.append(layer)
    gathered.append(next(layers))
    cloud = np.concatenate(gathered)

    def _query(points):
        # Build a tree on the current cloud and look up k + 1 neighbours of x.
        tree = KDTree(points, compact_nodes=False, balanced_tree=False)
        return tree.query(x, k=k+1, workers=-1)

    current, inds = _query(cloud)
    # Start from zeros so that the comparison below forces a re-query with
    # a larger cloud unless every distance found is already ~0.
    previous = np.zeros_like(current)

    # Grow the cloud until adding a layer leaves the distances unchanged.
    while not np.allclose(previous, current, atol=1e-10, rtol=0):
        previous = current
        cloud = np.vstack((cloud, next(layers)))
        current, inds = _query(cloud)

    return current[:, 1:], cloud, inds[:, 1:]
70
71
72
def nearest_neighbours_minval(motif, cell, min_val):
    """Nearest neighbour distances up to a distance threshold.

    The same as nearest_neighbours except a value is given instead of an
    integer k, the query points are the motif itself, and only distances
    are returned. The result has enough columns that every value in the
    last column is at least min_val.

    Parameters
    ----------
    motif : numpy.ndarray
        Cartesian coords of the motif; also used as the query points.
    cell : numpy.ndarray
        Cartesian coords of the unit cell.
    min_val : float
        Threshold that every entry of the last returned column must reach.

    Returns
    -------
    numpy.ndarray
        Distances from each motif point to its neighbours in ascending
        order, with column 0 of the raw query dropped, truncated after the
        first column whose entries are all >= min_val.
    """

    layers = generate_concentric_cloud(motif, cell)

    # Seed the cloud with the first three layers.
    cloud = np.concatenate([next(layers) for _ in range(3)])

    def _all_distances(points):
        # Distances from every motif point to every point of the cloud.
        tree = KDTree(points, compact_nodes=False, balanced_tree=False)
        dists, _ = tree.query(motif, k=points.shape[0], workers=-1)
        return dists

    def _first_column_reaching(dists):
        # Index of the first column whose entries are all >= min_val.
        return np.argwhere(np.all(dists >= min_val, axis=0))[0][0]

    latest = _all_distances(cloud)
    settled = np.zeros_like(latest)

    while True:
        # Stop once the previous result reaches min_val in its last column
        # and agrees exactly with the newest result on every column up to
        # the first one that reaches min_val.
        if np.all(settled[:, -1] >= min_val):
            cutoff = _first_column_reaching(settled)
            if np.array_equal(settled[:, :cutoff+1], latest[:, :cutoff+1]):
                break

        settled = latest
        cloud = np.vstack((cloud, next(layers)))
        latest = _all_distances(cloud)

    k = _first_column_reaching(settled)

    return settled[:, 1:k+1]
102
103
104
def generate_concentric_cloud(
        motif: np.ndarray,
        cell: np.ndarray
) -> Iterable[np.ndarray]:
    """
    Yield successive layers of points from the periodic set (motif, cell).

    Each batch of integer lattice points produced by
    generate_integer_lattice(cell.shape[0]) is mapped through the cell to
    Cartesian translations, and a copy of the motif is placed at every
    translation. Points are therefore yielded exactly once, in batches
    that get successively further away from the origin.

    Parameters
    ----------
    motif : ndarray
        Cartesian representation of the motif, shape (no points, dims).
    cell : ndarray
        Cartesian representation of the unit cell, shape (dims, dims).

    Yields
    ------
    ndarray
        Array of points from the periodic set, shape
        (no points * no translations in the batch, dims). Rows are grouped
        by translation, each group holding the motif points in order.
    """

    n_motif = len(motif)
    dims = cell.shape[0]

    for int_lattice in generate_integer_lattice(dims):

        translations = int_lattice @ cell
        n_translations = len(translations)
        layer = np.empty((n_motif * n_translations, dims))

        # View the flat layer as (translation, motif point, coordinate) and
        # fill every translated copy of the motif in one broadcast
        # assignment; block i of the layer is motif + translations[i].
        layer.reshape(n_translations, n_motif, dims)[...] = (
            motif + translations[:, None, :]
        )

        yield layer
140
141
142
def generate_integer_lattice(dims: int) -> Iterable[np.ndarray]:
    """Generate batches of integer lattice points of growing radius.

    Each yield gives all integer points (that have not already been
    yielded) inside the closed ball centred at the origin with radius d.
    d starts at 0 and increments by 1 after each yield. This generator
    never terminates on its own.

    For dims >= 2, only points with non-negative coordinates are
    enumerated directly; the rest are obtained by reflecting them with
    _reflect_positive_lattice.

    Parameters
    ----------
    dims : int
        The dimension of Euclidean space the lattice is in.

    Yields
    ------
    ndarray
        Array of integer points in dims dimensional Euclidean space,
        shape (no points in the batch, dims).
    """

    # For each tuple of the first dims - 1 coordinates, the smallest
    # non-negative last coordinate that has not been emitted yet.
    next_last = collections.defaultdict(int)
    radius = 0

    if dims == 1:
        # One dimension is handled directly: the origin, then -d and d.
        yield np.array([[0]])
        while True:
            radius += 1
            yield np.array([[-radius], [radius]])

    while True:
        fresh_points = []

        # Sweep all prefixes repeatedly, extending each column by one
        # point per sweep, until a full sweep adds nothing new.
        added = True
        while added:
            added = False
            for prefix in product(range(radius+1), repeat=dims-1):
                last = next_last[prefix]
                if _dist(prefix, last) <= radius**2:
                    fresh_points.append((*prefix, last))
                    next_last[prefix] = last + 1
                    added = True

        fresh_points = np.array(fresh_points)
        reflected = _reflect_positive_lattice(fresh_points)
        yield np.concatenate(reflected)
        radius += 1
185
186
187
@numba.njit()
def _dist(xy, z):
    """Return the sum of squares of z and of every value in xy.

    This is the squared Euclidean norm of the point (*xy, z); no square
    root is taken.
    """
    total = z ** 2
    for coord in xy:
        total = total + coord ** 2
    return total
193
194
195
# def generate_even_cloud(motif, cell):
196
#     m = len(motif)
197
#     lattice_generator = generate_even_lattice(cell)
198
199
#     while True:
200
#         lattice = next(lattice_generator)
201
#         layer = np.empty((m * len(lattice), cell.shape[0]))
202
203
#         for i, translation in enumerate(lattice):
204
#             layer[m * i : m * (i + 1)] = motif + translation
205
206
#         yield layer
207
208
209
# def generate_even_lattice(cell):
210
#     inv_cell = np.linalg.inv(cell)
211
#     n = cell.shape[0]
212
213
#     cell_lengths = np.linalg.norm(cell, axis=-1)
214
#     ratios = np.amax(cell_lengths) / cell_lengths
215
#     approx_ratios = np.copy(ratios)
216
#     xyz = np.zeros(n, dtype=int)
217
218
#     while True:
219
220
#         xyz_ = np.floor(approx_ratios).astype(int)
221
#         pve_int_lattice = []
222
#         for axis in range(n):
223
#             generators = [range(0, xyz_[d]) for d in range(axis)]
224
#             generators.append(range(xyz[axis], xyz_[axis]))
225
#             generators.extend(range(0, xyz[d]) for d in range(axis + 1, n))
226
#             pve_int_lattice.extend(product(*generators))
227
228
#         pve_int_lattice = np.array(pve_int_lattice)
229
#         pos_int_lat_cloud = np.concatenate(_reflect_positive_lattice(pve_int_lattice))
230
#         yield pos_int_lat_cloud @ cell
231
#         xyz = xyz_
232
#         approx_ratios += ratios
233
234
235
@numba.njit()
def _reflect_positive_lattice(positive_int_lattice):
    """Reflect a set of points in the +ve quadrant in all axes.

    For every non-empty combination of axes, the points with a non-zero
    coordinate on each of those axes are copied and negated along them.
    Points with a zero on any axis being reflected are skipped for that
    combination, so points lying on the axes are not duplicated.

    Returns a list of arrays: the input itself first, followed by one
    array per combination of axes, ordered by number of reflected axes
    and then by combination in lexicographic order.
    """
    dims = positive_int_lattice.shape[-1]
    batches = [positive_int_lattice]

    for n_reflections in range(1, dims + 1):

        # axes holds the current combination of n_reflections axes to
        # negate, starting from the first one (0, 1, ..., n_reflections-1).
        axes = np.arange(n_reflections)

        while True:
            # Keep only rows with no zero on any reflected axis; boolean
            # indexing yields a copy, so the input is left unmodified.
            keep = (positive_int_lattice[:, axes] == 0).sum(axis=-1) == 0
            reflected = positive_int_lattice[keep]
            reflected[:, axes] *= -1
            batches.append(reflected)

            # Advance axes to the next combination: find the rightmost
            # position that has not reached its maximum value yet.
            pos = n_reflections - 1
            for _ in range(n_reflections):
                if axes[pos] != pos + dims - n_reflections:
                    break
                pos -= 1
            else:
                # Every position is at its maximum: combinations exhausted.
                break
            axes[pos] += 1
            for j in range(pos + 1, n_reflections):
                axes[j] = axes[j - 1] + 1

    return batches
266
267
# # @numba.njit()
268
# def cartesian_product(n, repeat):
269
#     arrays = [np.arange(n)] * repeat
270
#     arr = np.empty(tuple([n] * repeat + [repeat]), dtype=np.int64)
271
#     for i, a in enumerate(np.ix_(*arrays)):
272
#         arr[..., i] = a
273
#     return arr.reshape(-1, repeat)
274
275
# def cartesian_product(*arrays):
276
#     la = len(arrays)
277
#     # dtype = np.result_type(*arrays)
278
#     arr = np.empty([len(a) for a in arrays] + [la])
279
#     for i, a in enumerate(np.ix_(*arrays)):
280
#         arr[..., i] = a
281
#     return arr.reshape(-1, la)
282