Passed
Branch master (3a171e)
by Daniel
01:43
created

amd._nearest_neighbours._distkey()   A

Complexity

Conditions 2

Size

Total Lines 6
Code Lines 6

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 6
dl 0
loc 6
rs 10
c 0
b 0
f 0
cc 2
nop 1
1
"""Implements core function nearest_neighbours used for AMD and PDD calculations."""
2
3
from typing import Iterable, Optional
4
import itertools
5
import collections
6
7
import numba
8
import numpy as np
9
import scipy.spatial
10
11
12
@numba.njit()
def _dist(xy, z):
    """Squared Euclidean norm of the point (*xy, z).

    Returns ``z**2`` plus the sum of the squares of the entries of ``xy``.
    No square root is taken. Compiled with numba.
    """
    total = z ** 2
    for component in xy:
        total = total + component ** 2
    return total
18
19
@numba.njit()
def _distkey(pt):
    """Sort key: squared Euclidean norm of ``pt`` (sum of squared coordinates).

    Compiled with numba; no square root is taken.
    """
    acc = 0
    for coord in pt:
        acc = acc + coord ** 2
    return acc
25
26
27
def _positive_lattice_shell(dims, d, ymax):
    """Return new integer points with all coordinates >= 0 and norm <= d.

    ``ymax`` maps each tuple of the first ``dims - 1`` coordinates (each in
    ``range(d + 1)``) to the next last coordinate not yet returned for that
    column. It is updated in place, so points returned for an earlier ``d``
    are not returned again. The result is a list of tuples sorted by squared
    norm.
    """
    shell = []
    while True:
        # One sweep adds at most one point (the next unused last coordinate)
        # per column; sweep again until no column gains a point.
        batch = []
        for xy in itertools.product(range(d + 1), repeat=dims - 1):
            if _dist(xy, ymax[xy]) <= d ** 2:
                batch.append((*xy, ymax[xy]))
                ymax[xy] += 1
        if not batch:
            return sorted(shell, key=_distkey)
        shell += batch


def _with_reflections(point, dims):
    """Return ``point`` followed by every sign flip of its nonzero coordinates.

    Only subsets whose coordinates are all nonzero are negated, so no point
    is produced twice. Reflections are returned as lists.
    """
    out = [point]
    for n_reflections in range(1, dims + 1):
        for indexes in itertools.combinations(range(dims), n_reflections):
            if all(point[i] for i in indexes):
                reflected = list(point)
                for i in indexes:
                    reflected[i] *= -1
                out.append(reflected)
    return out


def generate_integer_lattice(dims: int) -> Iterable[np.ndarray]:
    """Generates batches of integer lattice points.

    Each yield gives all points (that have not already been yielded)
    inside the closed ball centred at the origin with radius d. d starts
    at 0 and increments by 1 on each loop, so every integer point is
    yielded exactly once. Within a batch, points appear in non-decreasing
    order of distance from the origin.

    Parameters
    ----------
    dims : int
        The dimension of Euclidean space the lattice is in. Must be >= 1.

    Yields
    -------
    ndarray
        Integer array of shape (n_points, dims) holding points in dims
        dimensional Euclidean space.

    Raises
    ------
    ValueError
        If dims < 1 (raised when the generator is first advanced).
    """
    if dims < 1:
        raise ValueError(f"dims must be a positive integer, got {dims}")

    if dims == 1:
        # The general path below would iterate over empty tuples; the 1-D
        # lattice is simply 0, then -d and d for each d >= 1.
        yield np.array([[0]])
        for d in itertools.count(1):
            yield np.array([[-d], [d]])

    # ymax persists across radii so each shell contains only new points.
    ymax = collections.defaultdict(int)
    for d in itertools.count():
        # Build the non-negative orthant, then expand each point with its
        # reflections (reflections preserve the norm, so order is kept).
        int_lattice = [
            q
            for p in _positive_lattice_shell(dims, d, ymax)
            for q in _with_reflections(p, dims)
        ]
        yield np.array(int_lattice)
82
83
84
def generate_concentric_cloud(
        motif: np.ndarray,
        cell: np.ndarray
) -> Iterable[np.ndarray]:
    """
    Generates batches of points from a periodic set given by (motif, cell)
    which get successively further away from the origin.

    Each yield gives all points (that have not already been yielded) which
    lie in a unit cell whose corner lattice point is in the corresponding
    batch produced by generate_integer_lattice(cell.shape[0]).

    Parameters
    ----------
    motif : ndarray
        Cartesian representation of the motif, shape (no points, dims).
    cell : ndarray
        Cartesian representation of the unit cell, shape (dims, dims).

    Yields
    -------
    ndarray
        Arrays of points from the periodic set. Each array is the motif
        translated by every lattice vector of one batch, stacked
        translation by translation (all motif points for the first
        translation, then all for the second, and so on).
    """
    # Iterate the lattice generator directly rather than calling next() on
    # it inside this generator, which avoids the PEP 479 StopIteration trap.
    for int_points in generate_integer_lattice(cell.shape[0]):
        # Integer coefficients -> Cartesian translation vectors.
        translations = int_points @ cell
        yield np.concatenate([motif + translation for translation in translations])
114
115
116
def nearest_neighbours(
        motif: np.ndarray,
        cell: np.ndarray,
        k: int,
        asymmetric_unit: Optional[np.ndarray] = None):
    """
    Given a periodic set represented by (motif, cell) and an integer k, find
    the k nearest neighbours of the motif points in the periodic set.

    Note that cloud and inds are not used yet but may be in the future.

    Parameters
    ----------
    motif : ndarray
        Cartesian coords of the full motif, shape (no points, dims).
    cell : ndarray
        Cartesian coords of the unit cell, shape (dims, dims).
    k : int
        Number of nearest neighbours to find for each motif point. Must be
        at least 1.
    asymmetric_unit : ndarray, optional
        Indices pointing to an asymmetric unit in motif. If given, only
        those motif points are used as query points.

    Returns
    -------
    pdd : ndarray
        An array of shape (n_query, k) of distances from each query point
        to its k nearest neighbours in order, where n_query is
        motif.shape[0], or len(asymmetric_unit) if that is given. Points do
        not count as their own nearest neighbour. E.g. the distance to the
        n-th nearest neighbour of the m-th query point is pdd[m][n].
    cloud : ndarray
        The collection of points in the periodic set that were generated
        during the nearest neighbour search.
    inds : ndarray
        An array of shape (n_query, k) containing the indices of nearest
        neighbours in cloud. E.g. the n-th nearest neighbour to the m-th
        query point is cloud[inds[m][n]].

    Raises
    ------
    ValueError
        If k < 1. (With k == 0, scipy would squeeze the query result to
        1-D and the column slicing below would fail.)
    """
    if k < 1:
        raise ValueError(f"k must be a positive integer, got {k}")

    if asymmetric_unit is not None:
        asym_unit = motif[asymmetric_unit]
    else:
        asym_unit = motif

    def _query(points):
        # compact_nodes/balanced_tree off: faster construction at some query
        # cost, worthwhile because the tree is rebuilt on every refinement.
        tree = scipy.spatial.KDTree(points,
                                    compact_nodes=False,
                                    balanced_tree=False)
        # k + 1 because each query point is itself in the cloud (the first
        # layer is the untranslated motif) and comes back first at distance 0.
        return tree.query(asym_unit, k=k + 1, workers=-1)

    cloud_generator = generate_concentric_cloud(motif, cell)

    # Take layers until the cloud holds more than k points, plus one more.
    layers = []
    n_points = 0
    while n_points <= k:
        layer = next(cloud_generator)
        n_points += layer.shape[0]
        layers.append(layer)
    layers.append(next(cloud_generator))
    cloud = np.concatenate(layers)

    pdd_, inds = _query(cloud)
    pdd = np.zeros_like(pdd_)

    # Grow the cloud two layers at a time until the distances stop changing.
    while not np.allclose(pdd, pdd_, atol=1e-12, rtol=0):
        pdd = pdd_
        cloud = np.vstack((cloud,
                           next(cloud_generator),
                           next(cloud_generator)))
        pdd_, inds = _query(cloud)

    # Drop column 0 (each point's distance to itself).
    return pdd_[:, 1:], cloud, inds[:, 1:]
186
187
188
def nearest_neighbours_minval(motif, cell, min_val):
    """Return a PDD whose last column is the first one lying at or beyond min_val.

    A concentric cloud of the periodic set (motif, cell) is grown until the
    full distance matrix (every motif point to every cloud point, sorted per
    row) is stable up to and including the first column whose entries are
    all >= min_val. Stable means exactly equal between successive clouds.
    Columns 1..c of the stable matrix are then returned, where c is the index
    of that column. Column 0, each point's distance to itself, is dropped.

    Originally documented use: with min_val = 2 * diam(cell), the PDD is
    large enough to be reconstructed from.

    Parameters
    ----------
    motif : ndarray
        Cartesian coords of the motif; presumably shape (no points, dims),
        as in the other functions of this module.
    cell : ndarray
        Cartesian coords of the unit cell.
    min_val : float
        Distance threshold that the last returned column must reach for
        every motif point.

    Returns
    -------
    ndarray
        Array of shape (motif.shape[0], c) of sorted neighbour distances.
    """

    def _all_distances(points):
        # Distances from each motif point to every cloud point, sorted per row.
        tree = scipy.spatial.KDTree(points,
                                    compact_nodes=False,
                                    balanced_tree=False)
        dists, _ = tree.query(motif, k=points.shape[0], workers=-1)
        return dists

    def _first_col_beyond(dists):
        # Index of the first column whose entries are all >= min_val.
        return np.argwhere(np.all(dists >= min_val, axis=0))[0][0]

    def _converged(old, new):
        # The last column must already clear min_val, and everything up to
        # the first clearing column must be unchanged by the latest growth.
        if not np.all(old[:, -1] >= min_val):
            return False
        stop = _first_col_beyond(old) + 1
        return np.array_equal(old[:, :stop], new[:, :stop])

    cloud_generator = generate_concentric_cloud(motif, cell)
    cloud = np.concatenate([next(cloud_generator) for _ in range(3)])
    current = _all_distances(cloud)
    previous = np.zeros_like(current)

    while not _converged(previous, current):
        previous = current
        cloud = np.vstack((cloud,
                           next(cloud_generator),
                           next(cloud_generator)))
        current = _all_distances(cloud)

    return previous[:, 1:_first_col_beyond(previous) + 1]
223