Passed
Push — master ( 3ef848...37d7fb )
by Daniel
04:06
created

_reflect_positive_lattice()   B

Complexity

Conditions 8

Size

Total Lines 36
Code Lines 24

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 24
dl 0
loc 36
rs 7.3333
c 0
b 0
f 0
cc 8
nop 1
1
"""Implements the core function nearest_neighbours, used for AMD and
PDD calculations.
"""
4
5
import collections
6
from typing import Tuple, Iterable
7
from itertools import product
8
9
import numba
10
import numpy as np
11
import numpy.typing as npt
12
from scipy.spatial import KDTree
13
14
15
def nearest_neighbours(
        motif: npt.NDArray,
        cell: npt.NDArray,
        x: npt.NDArray,
        k: int
) -> Tuple[npt.NDArray[np.float64],
           npt.NDArray[np.float64],
           npt.NDArray[np.int64]]:
    """
    Given a periodic set represented by (motif, cell) and an integer k,
    find the k nearest neighbours in the periodic set to points in x.

    Parameters
    ----------
    motif : :class:`numpy.ndarray`
        Orthogonal (Cartesian) coords of the motif, shape (no points,
        dims).
    cell : :class:`numpy.ndarray`
        Orthogonal (Cartesian) coords of the unit cell, shape (dims,
        dims).
    x : :class:`numpy.ndarray`
        Array of points to query for neighbours, shape (no points,
        dims). For invariants of crystals this is the asymmetric unit.
        The first neighbour found for each point is discarded, on the
        assumption that it is the point itself (at distance 0), so the
        points of x are expected to belong to the periodic set.
    k : int
        Number of nearest neighbours to find for each point in x. Must
        be at least 1.

    Returns
    -------
    pdd : numpy.ndarray
        Array shape (x.shape[0], k) of distances from points in x to
        their k nearest neighbours in the periodic set, in order.
        E.g. pdd[m][n] is the distance from x[m] to its n-th nearest
        neighbour in the periodic set.
    cloud : numpy.ndarray
        Collection of points in the periodic set that was generated
        during the nearest neighbour search.
    inds : numpy.ndarray
        Array shape (x.shape[0], k) containing the indices of nearest
        neighbours in cloud. E.g. the n-th nearest neighbour to x[m] is
        cloud[inds[m][n]].

    Raises
    ------
    ValueError
        If k < 1.
    """

    # With k < 1 the KDTree query below would ask for a single neighbour
    # and return squeezed 1D arrays, which cannot be column-sliced.
    if k < 1:
        raise ValueError(f"k must be a positive integer, got {k}")

    cloud_generator = generate_concentric_cloud(motif, cell)

    # Take whole layers until the cloud has at least k+1 points (enough
    # for a (k+1)-neighbour query), then one extra layer.
    layers = []
    n_points = 0
    while n_points <= k:
        layer = next(cloud_generator)
        n_points += layer.shape[0]
        layers.append(layer)
    layers.append(next(cloud_generator))
    cloud = np.concatenate(layers)

    # Query k+1 neighbours: the first column is assumed to be each query
    # point matching itself and is dropped before returning.
    tree = KDTree(cloud, compact_nodes=False, balanced_tree=False)
    pdd_, inds = tree.query(x, k=k+1, workers=-1)
    pdd = np.zeros_like(pdd_, dtype=np.float64)

    # Grow the cloud one layer at a time until adding a layer changes no
    # distance by more than 1e-10.
    while not np.allclose(pdd, pdd_, atol=1e-10, rtol=0):
        pdd = pdd_
        cloud = np.vstack((cloud, next(cloud_generator)))
        tree = KDTree(cloud, compact_nodes=False, balanced_tree=False)
        pdd_, inds = tree.query(x, k=k+1, workers=-1)

    return pdd_[:, 1:], cloud, inds[:, 1:]
76
77
78
def nearest_neighbours_minval(
        motif: npt.NDArray,
        cell: npt.NDArray,
        min_val: float
) -> npt.NDArray[np.float64]:
    """Distances from each motif point to its nearest neighbours in the
    periodic set (motif, cell), with enough columns to reach ``min_val``.

    This is like nearest_neighbours, with two differences. The query
    points are the motif itself. Instead of a fixed number of neighbours
    k, the result has at least enough columns so that all values in the
    last column are at least ``min_val``.

    Parameters
    ----------
    motif : :class:`numpy.ndarray`
        Cartesian coords of the motif, shape (no points, dims).
    cell : :class:`numpy.ndarray`
        Cartesian coords of the unit cell, shape (dims, dims).
    min_val : float
        Threshold distance that the last returned column must reach.

    Returns
    -------
    numpy.ndarray
        Array with one row per motif point. Consider the full table of
        sorted distances from each motif point to the generated cloud.
        Column 0 of that table (presumably each point's zero distance to
        itself) is excluded. The returned columns run up to and including
        the first table column in which every entry is >= ``min_val``.
        The cloud is grown one layer at a time until those columns are
        exactly unchanged between two successive cloud sizes.
    """

    layers = generate_concentric_cloud(motif, cell)
    cloud = np.concatenate([next(layers) for _ in range(3)])

    def all_distances(points):
        # Sorted distances from every motif point to every point given.
        tree = KDTree(points, compact_nodes=False, balanced_tree=False)
        dists, _ = tree.query(motif, k=points.shape[0], workers=-1)
        return dists

    def first_col_reaching(table):
        # Index of the first column whose entries are all >= min_val.
        return np.argwhere(np.all(table >= min_val, axis=0))[0][0]

    def settled(prev, curr):
        # The previous table must already reach min_val in its last
        # column. Its columns up to the first one reaching min_val must
        # also be identical after one more layer was added.
        if not np.all(prev[:, -1] >= min_val):
            return False
        cut = first_col_reaching(prev) + 1
        return np.array_equal(prev[:, :cut], curr[:, :cut])

    current = all_distances(cloud)
    previous = np.zeros_like(current)

    while not settled(previous, current):
        previous = current
        cloud = np.vstack((cloud, next(layers)))
        current = all_distances(cloud)

    return previous[:, 1:first_col_reaching(previous) + 1]
111
112
113
def generate_concentric_cloud(
        motif: npt.NDArray,
        cell: npt.NDArray
) -> Iterable[npt.NDArray[np.float64]]:
    """
    Generates batches of points from a periodic set given by (motif,
    cell) which get successively further away from the origin.

    Each yield gives all points (that have not already been yielded)
    which lie in a unit cell whose corner lattice point was generated by
    ``generate_integer_lattice(cell.shape[0])``.

    Parameters
    ----------
    motif : :class:`numpy.ndarray`
        Cartesian representation of the motif, shape (no points, dims).
    cell : :class:`numpy.ndarray`
        Cartesian representation of the unit cell, shape (dims, dims).

    Yields
    -------
    :class:`numpy.ndarray`
        float64 arrays of points from the periodic set, shape
        (no lattice points in the batch * no motif points, dims).
        Points are grouped by lattice translation: the whole motif
        shifted by the batch's first lattice vector, then by its second,
        and so on.
    """

    motif = np.asarray(motif)
    dims = cell.shape[0]

    for int_lattice in generate_integer_lattice(dims):
        # Integer coefficients -> Cartesian translation vectors.
        lattice = int_lattice @ cell
        # (n_lattice, 1, dims) + (1, m, dims) -> (n_lattice, m, dims).
        # Flattening the first two axes keeps the translation-major order
        # that the previous per-translation copy loop produced.
        layer = lattice[:, np.newaxis, :] + motif[np.newaxis, :, :]
        yield layer.reshape(-1, dims).astype(np.float64, copy=False)
151
152
153
def generate_integer_lattice(dims: int) -> Iterable[npt.NDArray[np.float64]]:
    """Generates batches of integer lattice points. Each yield gives all
    points (that have not already been yielded) inside a closed sphere
    centered at the origin with radius d. d starts at 0 and increments
    by 1 on each loop.

    Parameters
    ----------
    dims : int
        The dimension of Euclidean space the lattice is in. Must be at
        least 1.

    Yields
    -------
    :class:`numpy.ndarray`
        float64 arrays of shape (no points, dims) whose entries are
        integers, i.e. integer points in dims dimensional Euclidean
        space.

    Raises
    ------
    ValueError
        If dims < 1 (raised on the first ``next`` call, as this is a
        generator).
    """

    if dims < 1:
        raise ValueError(f"dims must be a positive integer, got {dims}")

    if dims == 1:
        # float64 to match the annotation and the dims >= 2 branch, whose
        # batches come out of _reflect_positive_lattice as float64.
        yield np.array([[0.0]])
        d = 1
        while True:
            yield np.array([[-d], [d]], dtype=np.float64)
            d += 1

    # ymax[xy] is the next non-negative last coordinate not yet emitted
    # for the column of points whose leading dims-1 coordinates are xy.
    ymax = collections.defaultdict(int)
    d = 0

    while True:
        positive_int_lattice = []
        # Sweep all leading coordinates in [0, d] repeatedly; each sweep
        # adds at most one new point per xy column. Stop once a whole
        # sweep adds nothing.
        while True:
            batch = []
            for xy in product(range(d + 1), repeat=dims-1):
                if _dist(xy, ymax[xy]) <= d ** 2:
                    batch.append((*xy, ymax[xy]))
                    ymax[xy] += 1
            if not batch:
                break
            positive_int_lattice.extend(batch)

        # Mirror the non-negative points into every other orthant.
        yield _reflect_positive_lattice(np.array(positive_int_lattice))
        d += 1
194
195
196
@numba.njit()
def _dist(xy: Tuple[float, ...], z: float) -> float:
    """Return the squared Euclidean norm of the point ``(*xy, z)``.

    ``xy`` may be a tuple of any length; the result is ``z ** 2`` plus
    the sum of the squares of the entries of ``xy``. No square root is
    taken, so the value must be compared against a squared radius.
    """
    s = z ** 2
    for val in xy:
        s += val ** 2
    return s
202
203
204
@numba.njit()
def _reflect_positive_lattice(
        positive_int_lattice: npt.NDArray
) -> npt.NDArray[np.float64]:
    """Reflect a set of points in the +ve quadrant in all axes. Does not
    duplicate points lying on the axes themselves.

    The axis subsets are visited in this order: subsets of size 1, then
    size 2, up to all ``dims`` axes, with each size enumerated in
    lexicographic order. For every non-empty subset, the points that
    have no zero among those coordinates are copied with those
    coordinates negated (see _reflect_batch). The input points come
    first in the output, followed by each reflected batch in turn.

    Parameters
    ----------
    positive_int_lattice : numpy.ndarray
        2D array of points, shape (no points, dims). Coordinates are
        expected to be non-negative; this is not checked here.

    Returns
    -------
    numpy.ndarray
        float64 array of shape (no output points, dims).
    """

    dims = positive_int_lattice.shape[-1]
    # Collect individual rows; they are packed into one array at the end.
    batches = []
    batches.extend(positive_int_lattice)

    for n_reflections in range(1, dims + 1):

        # indices holds the current combination of n_reflections axes,
        # starting from the lexicographically first one [0, ..., n-1].
        indices = np.arange(n_reflections)
        batches.extend(_reflect_batch(positive_int_lattice, indices))

        # Advance indices to the next combination until the last one,
        # [dims - n_reflections, ..., dims - 1], has been processed.
        while True:
            # Find the rightmost position not yet at its maximum value
            # (i + dims - n_reflections).
            i = n_reflections - 1
            for _ in range(n_reflections):
                if indices[i] != i + dims - n_reflections:
                    break
                i -= 1
            else:
                # Every position is at its maximum: no combinations left.
                break
            # Increment that position and reset everything to its right
            # to the smallest strictly increasing continuation.
            indices[i] += 1
            for j in range(i+1, n_reflections):
                indices[j] = indices[j-1] + 1

            batches.extend(_reflect_batch(positive_int_lattice, indices))

    # Copy the collected rows into a single float64 array.
    int_lattice = np.empty(shape=(len(batches), dims), dtype=np.float64)
    for i in range(len(batches)):
        int_lattice[i] = batches[i]

    return int_lattice
240
241
242
@numba.njit()
def _reflect_batch(
        positive_int_lattice: npt.NDArray,
        indices: npt.NDArray
) -> npt.NDArray:
    """Negate the coordinates listed in ``indices`` for every point that
    is non-zero in all of those coordinates.

    Points with a zero in any of the chosen coordinates are left out.
    Negating a zero coordinate does nothing, so such a point would
    duplicate one obtained by reflecting in fewer axes. Boolean-mask
    indexing produces a copy, so the input array is not modified.
    """

    # Count, per point, how many of the chosen coordinates are zero.
    n_zero_coords = (positive_int_lattice[:, indices] == 0).sum(axis=-1)
    # Keep only the points lying off every chosen axis.
    reflected = positive_int_lattice[n_zero_coords == 0]
    reflected[:, indices] *= -1
    return reflected
256