Passed
Push — master ( a4dae9...10a6a8 )
by Daniel
07:04
created

amd._nns   A

Complexity

Total Complexity 28

Size/Duplication

Total Lines 225
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
wmc 28
eloc 114
dl 0
loc 225
rs 10
c 0
b 0
f 0

6 Functions

Rating   Name   Duplication   Size   Complexity  
A _dist() 0 6 2
A generate_concentric_cloud() 0 36 3
B nearest_neighbours_minval() 0 30 5
A nearest_neighbours() 0 58 3
B generate_integer_lattice() 0 43 8
B _reflect_positive_lattice() 0 30 7
1
"""Implements core function nearest_neighbours used for AMD and PDD calculations."""
2
3
import collections
4
from typing import Iterable
5
from itertools import product
6
7
import numba
8
import numpy as np
9
from scipy.spatial import KDTree
10
11
12
def nearest_neighbours(
        motif: np.ndarray,
        cell: np.ndarray,
        x: np.ndarray,
        k: int):
    """
    Given a periodic set represented by (motif, cell) and an integer k, find
    the k nearest neighbours in the periodic set to points in x.

    Parameters
    ----------
    motif : numpy.ndarray
        Orthogonal (Cartesian) coords of the motif, shape (no points, dims).
    cell : numpy.ndarray
        Orthogonal (Cartesian) coords of the unit cell, shape (dims, dims).
    x : numpy.ndarray
        Array of points to query for neighbours, shape (no query points,
        dims). For invariants of crystals this is the asymmetric unit.
        The closest match found for each query point is discarded, on the
        assumption that it is the query point itself (distance 0), so every
        point of x should be a point of the periodic set.
    k : int
        Number of nearest neighbours to find for each point in x.

    Returns
    -------
    pdd : numpy.ndarray
        Array shape (x.shape[0], k) of distances from points in x
        to their k nearest neighbours in the periodic set, in order.
        E.g. pdd[m][n] is the distance from x[m] to its n-th nearest
        neighbour in the periodic set.
    cloud : numpy.ndarray
        Collection of points in the periodic set that was generated
        during the nearest neighbour search.
    inds : numpy.ndarray
        Array shape (x.shape[0], k) containing the indices of
        nearest neighbours in cloud. E.g. the n-th nearest neighbour to
        x[m] is cloud[inds[m][n]].
    """

    cloud_generator = generate_concentric_cloud(motif, cell)

    # Take layers until the cloud holds more than k points, then take one
    # extra layer before the first query.
    layers = []
    n_points = 0
    while n_points <= k:
        layer = next(cloud_generator)
        n_points += layer.shape[0]
        layers.append(layer)
    layers.append(next(cloud_generator))
    cloud = np.concatenate(layers)

    # Ask for k + 1 neighbours: column 0 is the query point itself.
    tree = KDTree(cloud, compact_nodes=False, balanced_tree=False)
    pdd_, inds = tree.query(x, k=k + 1, workers=-1)
    pdd = np.zeros_like(pdd_)

    # Keep growing the cloud by one layer until adding a layer no longer
    # changes any of the k + 1 distances (to within 1e-12).
    while not np.allclose(pdd, pdd_, atol=1e-12, rtol=0):
        pdd = pdd_
        cloud = np.vstack((cloud, next(cloud_generator)))
        tree = KDTree(cloud, compact_nodes=False, balanced_tree=False)
        pdd_, inds = tree.query(x, k=k + 1, workers=-1)

    # Drop the self-match column from both distances and indices.
    return pdd_[:, 1:], cloud, inds[:, 1:]
70
71
72
def nearest_neighbours_minval(motif, cell, min_val):
    """Like nearest_neighbours, but driven by a distance instead of a count.

    Rather than a fixed number k of neighbours, a value ``min_val`` is
    given. The motif points themselves are the query points. The returned
    array has just enough columns that every entry of its last column is
    at least ``min_val``.
    """

    layers = generate_concentric_cloud(motif, cell)

    def _all_distances(points):
        # Distances from every motif point to every point of `points`,
        # sorted per row; column 0 is each motif point's self-distance.
        tree = KDTree(points, compact_nodes=False, balanced_tree=False)
        dists, _ = tree.query(motif, k=points.shape[0], workers=-1)
        return dists

    cloud = np.concatenate([next(layers) for _ in range(3)])
    new_pdd = _all_distances(cloud)
    old_pdd = np.zeros_like(new_pdd)

    while True:
        if np.all(old_pdd[:, -1] >= min_val):
            # First column in which every row has reached min_val.
            first_col = np.argwhere(np.all(old_pdd >= min_val, axis=0))[0][0]
            stop = first_col + 1
            # Stop once one more layer left those leading columns untouched.
            if np.array_equal(old_pdd[:, :stop], new_pdd[:, :stop]):
                break

        old_pdd = new_pdd
        cloud = np.vstack((cloud, next(layers)))
        new_pdd = _all_distances(cloud)

    return old_pdd[:, 1:first_col + 1]
102
103
104
def generate_concentric_cloud(
        motif: np.ndarray,
        cell: np.ndarray
) -> Iterable[np.ndarray]:
    """
    Generates batches of points from a periodic set given by (motif, cell)
    which get successively further away from the origin.

    Each yield gives all points (that have not already been yielded) which
    lie in a unit cell whose corner lattice point was in the corresponding
    batch of generate_integer_lattice(cell.shape[0]).

    Parameters
    ----------
    motif : ndarray
        Cartesian representation of the motif, shape (no points, dims).
    cell : ndarray
        Cartesian representation of the unit cell, shape (dims, dims).

    Yields
    -------
    ndarray
        Float64 arrays of points from the periodic set. Within a batch, the
        copies of the motif for each lattice translation are contiguous and
        in motif order: row m*i + j is motif[j] + translation i.
    """

    motif = np.asarray(motif)
    dims = cell.shape[0]

    # Iterate the lattice generator directly instead of calling next() on it.
    # If it ever stops, this generator then returns normally instead of a
    # StopIteration being turned into a RuntimeError (PEP 479).
    for int_lattice in generate_integer_lattice(dims):
        translations = int_lattice @ cell
        # Broadcast (n_translations, 1, dims) + (1, m, dims); the row-major
        # reshape puts motif[j] + translations[i] at row m*i + j.
        layer = translations[:, None, :] + motif[None, :, :]
        yield layer.reshape(-1, dims).astype(np.float64, copy=False)
140
141
142
def generate_integer_lattice(dims: int) -> Iterable[np.ndarray]:
    """Generates batches of integer lattice points.

    The n-th yield (n = 0, 1, 2, ...) contains every integer point whose
    Euclidean distance from the origin is at most n and which was not in
    an earlier yield.

    Parameters
    ----------
    dims : int
        The dimension of Euclidean space the lattice is in.

    Yields
    -------
    ndarray
        Yields arrays of integer points in dims dimensional Euclidean space.
    """

    if dims == 1:
        # In one dimension each shell is just the pair -r, r.
        yield np.array([[0]])
        radius = 0
        while True:
            radius += 1
            yield np.array([[-radius], [radius]])

    # For each tuple of the first dims - 1 coordinates, the smallest last
    # coordinate not yet emitted in the non-negative orthant.
    next_height = collections.defaultdict(int)
    radius = 0

    while True:
        radius_sq = radius ** 2
        orthant_points = []

        # Sweep the (dims-1)-dim base grid repeatedly, raising each column
        # by one point per sweep while it stays inside the current radius.
        added = True
        while added:
            added = False
            for base in product(range(radius + 1), repeat=dims - 1):
                height = next_height[base]
                if _dist(base, height) <= radius_sq:
                    orthant_points.append((*base, height))
                    next_height[base] = height + 1
                    added = True

        reflected = _reflect_positive_lattice(np.array(orthant_points))
        yield np.concatenate(reflected)
        radius += 1
185
186
187
@numba.njit()
def _reflect_positive_lattice(positive_int_lattice):
    """Reflect a set of points in the +ve orthant in all axes.

    Returns a list of arrays. The first is the input itself. Then, for every
    non-empty combination of axes (by combination size, then in
    lexicographic order), comes a copy of the points that are non-zero on
    all of those axes, with those coordinates negated. Points that are zero
    on a chosen axis are skipped, because reflecting them would reproduce a
    point already present.
    """
    dims = positive_int_lattice.shape[-1]
    batches = [positive_int_lattice]

    for n_axes in range(1, dims + 1):
        # Current combination of axis indices, starting at (0, ..., n_axes-1).
        axes = np.arange(n_axes)
        while True:
            nonzero_on_axes = (positive_int_lattice[:, axes] == 0).sum(axis=-1) == 0
            reflected = positive_int_lattice[nonzero_on_axes]
            reflected[:, axes] *= -1
            batches.append(reflected)

            # Advance to the next combination: locate the rightmost entry
            # that is still below its maximum value (pos + dims - n_axes).
            pos = n_axes - 1
            while pos >= 0 and axes[pos] == pos + dims - n_axes:
                pos -= 1
            if pos < 0:
                # Every entry is at its maximum: all combinations are done.
                break
            axes[pos] += 1
            for j in range(pos + 1, n_axes):
                axes[j] = axes[j - 1] + 1

    return batches
217
218
219
@numba.njit()
def _dist(xy, z):
    """Return the squared Euclidean norm of the point (*xy, z).

    xy holds the leading coordinates and z the final one. No square root is
    taken.
    """
    squared = z ** 2
    for coordinate in xy:
        squared += coordinate ** 2
    return squared
225
226
227
# # @numba.njit()
228
# def cartesian_product(n, repeat):
229
#     arrays = [np.arange(n)] * repeat
230
#     arr = np.empty(tuple([n] * repeat + [repeat]), dtype=np.int64)
231
#     for i, a in enumerate(np.ix_(*arrays)):
232
#         arr[..., i] = a
233
#     return arr.reshape(-1, repeat)
234