Passed
Push — master ( c026c2...632976 )
by Daniel
06:58
created

amd._nns.generate_even_lattice()   A

Complexity

Conditions 3

Size

Total Lines 22
Code Lines 19

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 19
dl 0
loc 22
rs 9.45
c 0
b 0
f 0
cc 3
nop 1
1
"""Implements core function nearest_neighbours used for AMD and PDD calculations."""
2
3
import collections
4
from typing import Iterable
5
from itertools import product
6
7
import numba
8
import numpy as np
9
from scipy.spatial import KDTree
10
11
12
def nearest_neighbours(
        motif: np.ndarray,
        cell: np.ndarray,
        x: np.ndarray,
        k: int):
    """
    Given a periodic set represented by (motif, cell) and an integer k, find
    the k nearest neighbours in the periodic set to points in x.

    The KD-tree is queried for k + 1 neighbours and the first column is
    discarded. This assumes every point of x is itself in the generated
    cloud (e.g. x is a subset of motif), so that its nearest neighbour is
    itself at distance 0.

    Parameters
    ----------
    motif : numpy.ndarray
        Orthogonal (Cartesian) coords of the motif, shape (no points, dims).
    cell : numpy.ndarray
        Orthogonal (Cartesian) coords of the unit cell, shape (dims, dims).
    x : numpy.ndarray
        Array of points to query for neighbours, shape (no queries, dims).
        For invariants of crystals this is the asymmetric unit.
    k : int
        Number of nearest neighbours to find for each point in x.

    Returns
    -------
    pdd : numpy.ndarray
        Array shape (x.shape[0], k) of distances from points in x to their
        k nearest neighbours in the periodic set, in increasing order.
        E.g. pdd[m][n] is the distance from x[m] to its (n+1)-th nearest
        neighbour (n counted from 0).
    cloud : numpy.ndarray
        Collection of points in the periodic set that was generated
        during the nearest neighbour search.
    inds : numpy.ndarray
        Array shape (x.shape[0], k) containing the indices of nearest
        neighbours in cloud. E.g. the neighbour whose distance is pdd[m][n]
        is cloud[inds[m][n]].
    """

    cloud_generator = generate_concentric_cloud(motif, cell)

    # Take layers until the cloud holds more than k points, plus one extra
    # layer, before the first query.
    layers = []
    n_points = 0
    while n_points <= k:
        layer = next(cloud_generator)
        n_points += layer.shape[0]
        layers.append(layer)
    layers.append(next(cloud_generator))
    cloud = np.concatenate(layers)

    tree = KDTree(cloud, compact_nodes=False, balanced_tree=False)
    pdd_, inds = tree.query(x, k=k+1, workers=-1)
    pdd = np.zeros_like(pdd_)

    # Keep adding layers until one more layer no longer changes the distances.
    while not np.allclose(pdd, pdd_, atol=1e-10, rtol=0):
        pdd = pdd_
        cloud = np.vstack((cloud, next(cloud_generator)))
        tree = KDTree(cloud, compact_nodes=False, balanced_tree=False)
        pdd_, inds = tree.query(x, k=k+1, workers=-1)

    # Drop column 0 (each query point's match with itself, see docstring).
    return pdd_[:, 1:], cloud, inds[:, 1:]
70
71
72
def nearest_neighbours_minval(motif, cell, min_val):
    """The same as nearest_neighbours except a value is given instead of an
    integer k and the result has at least enough columns so all values in
    the last column are at least the given value.

    Parameters
    ----------
    motif : numpy.ndarray
        Orthogonal (Cartesian) coords of the motif, shape (no points, dims).
        The motif points themselves are the query points.
    cell : numpy.ndarray
        Orthogonal (Cartesian) coords of the unit cell, shape (dims, dims).
    min_val : float
        Threshold that every entry of the last returned column must reach.

    Returns
    -------
    numpy.ndarray
        Sorted distances from each motif point to its nearest neighbours in
        the periodic set, shape (motif.shape[0], k). k is the smallest count
        such that every row's k-th distance is >= min_val. Each point's
        zero distance to itself (column 0 of the raw query) is excluded.
    """

    cloud_generator = generate_concentric_cloud(motif, cell)

    # Start from the first three layers of the cloud.
    cloud = np.concatenate([next(cloud_generator) for _ in range(3)])
    tree = KDTree(cloud, compact_nodes=False, balanced_tree=False)
    pdd_, _ = tree.query(motif, k=cloud.shape[0], workers=-1)
    pdd = np.zeros_like(pdd_)

    while True:
        if np.all(pdd[:, -1] >= min_val):
            # First column in which every row has reached min_val.
            col_where = np.argwhere(np.all(pdd >= min_val, axis=0))[0][0]
            # Stop once an extra layer no longer changes the needed columns.
            # Use the same tolerance as nearest_neighbours rather than exact
            # float equality.
            if np.allclose(pdd[:, :col_where+1], pdd_[:, :col_where+1],
                           atol=1e-10, rtol=0):
                break

        pdd = pdd_
        cloud = np.vstack((cloud, next(cloud_generator)))
        tree = KDTree(cloud, compact_nodes=False, balanced_tree=False)
        pdd_, _ = tree.query(motif, k=cloud.shape[0], workers=-1)

    return pdd[:, 1:col_where+1]
102
103
104
def generate_concentric_cloud(
        motif: np.ndarray,
        cell: np.ndarray
) -> Iterable[np.ndarray]:
    """
    Yield successive batches of points of the periodic set (motif, cell).

    Each batch of integer points from generate_integer_lattice(cell.shape[0])
    is mapped to Cartesian lattice translations with ``@ cell``. A copy of
    the motif is then placed at every translation. Because the integer
    batches move outward from the origin, so do the yielded point batches.

    Parameters
    ----------
    motif : ndarray
        Cartesian representation of the motif, shape (no points, dims).
    cell : ndarray
        Cartesian representation of the unit cell, shape (dims, dims).

    Yields
    -------
    ndarray
        Float array of shape (len(motif) * n_translations, dims). Rows
        i*len(motif) .. (i+1)*len(motif)-1 are motif + translation i.
    """

    n_motif = len(motif)
    dims = cell.shape[0]

    for int_points in generate_integer_lattice(dims):
        translations = int_points @ cell
        n_translations = len(translations)
        layer = np.empty((n_motif * n_translations, dims))
        # Write through a (translation, motif point, coord) view of the flat
        # buffer so each translated motif copy lands in its own row block.
        layer.reshape(n_translations, n_motif, dims)[...] = (
            translations[:, np.newaxis, :] + motif
        )
        yield layer
140
141
142
def generate_integer_lattice(dims: int) -> Iterable[np.ndarray]:
    """Yield integer lattice points in shells of growing radius.

    The r-th yield (r = 0, 1, 2, ...) contains every integer point of
    ``dims``-dimensional space that lies within Euclidean distance r of the
    origin and was not part of an earlier yield.

    Parameters
    ----------
    dims : int
        The dimension of Euclidean space the lattice is in.

    Yields
    -------
    ndarray
        Integer array of shape (n_new_points, dims).
    """

    if dims == 1:
        yield np.array([[0]])
        radius = 1
        while True:
            yield np.array([[-radius], [radius]])
            radius += 1

    # Key: a non-negative prefix (x_1, ..., x_{dims-1}). Value: the smallest
    # non-negative last coordinate not emitted yet for that prefix.
    next_z = collections.defaultdict(int)
    radius = 0

    while True:
        r_sq = radius ** 2
        shell = []
        # Sweep over all prefixes, raising each column by at most one point
        # per sweep, until a full sweep adds nothing.
        added = True
        while added:
            added = False
            for prefix in product(range(radius + 1), repeat=dims - 1):
                z = next_z[prefix]
                if _dist(prefix, z) <= r_sq:
                    shell.append((*prefix, z))
                    next_z[prefix] = z + 1
                    added = True

        # Mirror the non-negative shell points into every orthant.
        batches = _reflect_positive_lattice(np.array(shell))
        yield np.concatenate(batches)
        radius += 1
185
186
187
@numba.njit()
def _dist(xy, z):
    """Return the squared Euclidean norm of the point (*xy, z)."""
    total = z * z
    for coord in xy:
        total += coord * coord
    return total
193
194
195
def generate_even_cloud(motif, cell):
    """Yield successive batches of points of the periodic set (motif, cell).

    The translations come from generate_even_lattice(cell), which already
    returns Cartesian lattice points. A copy of the motif is placed at
    every translation in each batch.

    Parameters
    ----------
    motif : numpy.ndarray
        Cartesian coords of the motif, shape (no points, dims).
    cell : numpy.ndarray
        Cartesian coords of the unit cell, shape (dims, dims).

    Yields
    ------
    numpy.ndarray
        Float array of shape (len(motif) * n_translations, dims). Rows
        i*len(motif) .. (i+1)*len(motif)-1 are motif + translation i.
    """
    n_motif = len(motif)
    dims = cell.shape[0]

    for translations in generate_even_lattice(cell):
        n_translations = len(translations)
        layer = np.empty((n_motif * n_translations, dims))
        # Write each translated motif copy into its own row block via a
        # (translation, motif point, coord) view of the flat buffer.
        layer.reshape(n_translations, n_motif, dims)[...] = (
            translations[:, np.newaxis, :] + motif
        )
        yield layer
207
208
209
def generate_even_lattice(cell):
    """Yield batches of Cartesian lattice points that grow roughly evenly
    along every cell vector.

    Row i of ``cell`` is stepped in proportion to
    max(|cell rows|) / |cell[i]|, so shorter vectors get more integer steps.
    On each iteration the non-negative integer box grows. The points in the
    new box but not the previous one are reflected into the other orthants
    by _reflect_positive_lattice, then mapped to Cartesian coordinates with
    ``@ cell``.

    Parameters
    ----------
    cell : numpy.ndarray
        Cartesian coords of the unit cell, shape (dims, dims); rows are the
        lattice vectors.

    Yields
    ------
    numpy.ndarray
        Array of shape (n_new_points, dims) of Cartesian lattice points.

    Raises
    ------
    ValueError
        On the first ``next()`` if any cell vector has zero (or NaN) length.
        Otherwise its step ratio would be infinite/undefined and the integer
        bounds meaningless.
    """
    dims = cell.shape[0]
    cell_lengths = np.linalg.norm(cell, axis=-1)
    if not np.all(cell_lengths > 0):
        raise ValueError(
            'generate_even_lattice requires every cell vector to have '
            'non-zero length.'
        )
    ratios = np.amax(cell_lengths) / cell_lengths
    approx_ratios = np.copy(ratios)
    prev_bounds = np.zeros(dims, dtype=int)

    while True:
        bounds = np.floor(approx_ratios).astype(int)
        pve_int_lattice = []
        for axis in range(dims):
            # New points whose highest-indexed coordinate outside the old box
            # is `axis`. Earlier axes span the full new range, `axis` spans
            # only the new slab, and later axes stay in the old range. This
            # partitions new box \ old box without duplicates.
            generators = [range(0, bounds[d]) for d in range(axis)]
            generators.append(range(prev_bounds[axis], bounds[axis]))
            generators.extend(
                range(0, prev_bounds[d]) for d in range(axis + 1, dims)
            )
            pve_int_lattice.extend(product(*generators))

        pve_int_lattice = np.array(pve_int_lattice)
        int_lattice = np.concatenate(_reflect_positive_lattice(pve_int_lattice))
        yield int_lattice @ cell
        prev_bounds = bounds
        approx_ratios += ratios
231
232
233
@numba.njit()
def _reflect_positive_lattice(positive_int_lattice):
    """Reflect a set of points in the +ve quadrant in all axes.

    The first batch returned is the input itself. Then, for every non-empty
    subset of axes (by subset size, each size in lexicographic order), the
    rows whose coordinates on those axes are all non-zero are copied and
    negated on exactly those axes.
    """
    n_dims = positive_int_lattice.shape[-1]
    reflected = [positive_int_lattice]

    for size in range(1, n_dims + 1):
        # Lexicographically first subset of `size` axes: (0, 1, ..., size-1).
        axes = np.arange(size)
        while True:
            nonzero_rows = (positive_int_lattice[:, axes] == 0).sum(axis=-1) == 0
            flipped = positive_int_lattice[nonzero_rows]
            flipped[:, axes] *= -1
            reflected.append(flipped)

            # Find the rightmost slot that has not reached its maximum value.
            pos = size - 1
            while pos >= 0 and axes[pos] == pos + n_dims - size:
                pos -= 1
            if pos < 0:
                break
            # Advance that slot and reset every slot to its right.
            axes[pos] += 1
            for j in range(pos + 1, size):
                axes[j] = axes[j - 1] + 1

    return reflected
263
264
265
# # @numba.njit()
266
# def cartesian_product(n, repeat):
267
#     arrays = [np.arange(n)] * repeat
268
#     arr = np.empty(tuple([n] * repeat + [repeat]), dtype=np.int64)
269
#     for i, a in enumerate(np.ix_(*arrays)):
270
#         arr[..., i] = a
271
#     return arr.reshape(-1, repeat)
272
273
274
def cartesian_product(*arrays, dtype=float):
    """Return the Cartesian product of 1-D sequences as a 2-D array.

    Each row is one combination (a_0[i_0], ..., a_{n-1}[i_{n-1}]). The last
    array varies fastest, the same order as ``itertools.product``.

    Parameters
    ----------
    *arrays : 1-D array_like
        The sequences to combine.
    dtype : data-type or None, optional, keyword-only
        dtype of the result. Defaults to float (float64), the original
        behaviour. Pass None to use ``np.result_type`` of the inputs instead,
        e.g. to keep integer inputs as integers.

    Returns
    -------
    numpy.ndarray
        Shape (prod(len(a) for a in arrays), len(arrays)). With no arrays,
        a single empty combination of shape (1, 0) is returned.
    """
    n_arrays = len(arrays)
    if n_arrays == 0:
        # The product of zero sets is one empty tuple.
        return np.empty((1, 0), dtype=float if dtype is None else dtype)
    if dtype is None:
        dtype = np.result_type(*(np.asarray(a) for a in arrays))

    out = np.empty([len(a) for a in arrays] + [n_arrays], dtype=dtype)
    # np.ix_ returns open-mesh views. Broadcasting each into its own column
    # of `out` fills every combination without materialising index grids.
    for i, grid in enumerate(np.ix_(*arrays)):
        out[..., i] = grid
    return out.reshape(-1, n_arrays)
281