Test Failed
Push — master (37d7fb...c02a6e)
created by Daniel at 07:38

amd.reconstruct._remove_values_within_tol()   A

Complexity: Conditions 1
Size: Total Lines 3, Code Lines 2
Duplication: Lines 0, Ratio 0 %
Importance: Changes 0
Metric Value
eloc 2
dl 0
loc 3
rs 10
c 0
b 0
f 0
cc 1
nop 3
"""Functions for reconstructing a periodic set up to isometry from its
PDD. This is possible 'in a general position'; see our papers for more
details.
"""

from itertools import combinations, permutations, product

import numpy as np
import numpy.typing as npt
import numba
from scipy.spatial.distance import cdist
from scipy.spatial import KDTree

from ._nearest_neighbours import generate_concentric_cloud
from .utils import diameter


def reconstruct(
        pdd: npt.NDArray,
        cell: npt.NDArray
) -> npt.NDArray:
    """Reconstruct a motif from a PDD and unit cell. This function only
    works if ``pdd`` has enough columns that every value in its last
    column is larger than twice the diameter of the unit cell. It also
    expects an uncollapsed PDD with no weights column. Do not use
    ``amd.PDD`` to compute the PDD for this function; use
    ``amd.PDD_reconstructable`` instead, which returns a PDD in the form
    this function expects. This function is quite slow and its run time
    can vary widely with the input.

    Parameters
    ----------
    pdd : :class:`numpy.ndarray`
        The PDD of the periodic set to reconstruct. ``k`` must be at
        least large enough that all values in the last column of ``pdd``
        are greater than :code:`2 * diameter(cell)`, and the PDD must be
        uncollapsed and without a weights column. Use
        ``amd.PDD_reconstructable`` to get a PDD acceptable for this
        argument.
    cell : :class:`numpy.ndarray`
        Unit cell of the periodic set to reconstruct.

    Returns
    -------
    :class:`numpy.ndarray`
        The reconstructed motif of the periodic set.
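
    Examples
    --------
    A minimal sketch of the intended workflow (``pset`` is assumed to be
    a periodic set with a ``cell`` attribute; the exact signature of
    ``amd.PDD_reconstructable`` may differ)::

        pdd = amd.PDD_reconstructable(pset)
        motif = amd.reconstruct.reconstruct(pdd, pset.cell)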
    """

    # TODO: get a more reduced neighbour set
    # TODO: find all shared distances in a big operation at the start
    # TODO: move PREC variable to its proper place
    PREC = 1e-10

    dims = cell.shape[0]
    if dims not in (2, 3):
        raise ValueError(
            'Reconstructing from PDD only implemented for 2 and 3 dimensions'
        )

    diam = diameter(cell)
    # set first point as origin wlog, return if 1 motif point
    motif = [np.zeros((dims, ))]

    if pdd.shape[0] == 1:
        motif = np.array(motif)
        return motif

    # find lattice distances so we can ignore them
    cloud_generator = generate_concentric_cloud(np.array(motif), cell)
    cloud = []
    next(cloud_generator)
    layer = next(cloud_generator)
    while np.any(np.linalg.norm(layer, axis=-1) <= diam):
        cloud.append(layer)
        layer = next(cloud_generator)
    cloud = np.concatenate(cloud)

    # nn_set: (a superset of) the lattice points whose Voronoi domains
    # neighbour the origin's
    nn_set = _neighbour_set(cell, PREC)
    lattice_dists = np.linalg.norm(cloud, axis=-1)
    lattice_dists = lattice_dists[lattice_dists <= diam]
    lattice_dists = _unique_within_tol(lattice_dists, PREC)

    # remove lattice distances from first and second rows
    row1_reduced = _remove_vals(pdd[0], lattice_dists, PREC)
    row2_reduced = _remove_vals(pdd[1], lattice_dists, PREC)
    # get shared dists between first and second rows
    shared_dists = _shared_vals(row1_reduced, row2_reduced, PREC)
    shared_dists = _unique_within_tol(shared_dists, PREC)

    # all combinations of vecs in the neighbour set forming a basis
    bases = []
    for vecs in combinations(nn_set, dims):
        vecs = np.asarray(vecs)
        if np.abs(np.linalg.det(vecs)) > PREC:
            bases.extend(basis for basis in permutations(vecs, dims))

    q = _find_second_point(shared_dists, bases, cloud, PREC)

    if q is None:
        raise RuntimeError('Second point of motif could not be found.')

    motif.append(q)

    if pdd.shape[0] == 2:
        motif = np.array(motif)
        return motif

    for row in pdd[2:, :]:
        row_reduced = _remove_vals(row, lattice_dists, PREC)
        shared_dists1 = _shared_vals(row1_reduced, row_reduced, PREC)
        shared_dists2 = _shared_vals(row2_reduced, row_reduced, PREC)
        shared_dists1 = _unique_within_tol(shared_dists1, PREC)
        shared_dists2 = _unique_within_tol(shared_dists2, PREC)
        q_ = _find_further_point(
            shared_dists1, shared_dists2, bases, cloud, q, PREC
        )

        if q_ is None:
            raise RuntimeError('Further point of motif could not be found.')

        motif.append(q_)

    motif = np.array(motif)
    # map all motif points back into the unit cell
    motif = np.mod(motif @ np.linalg.inv(cell), 1) @ cell
    return motif


def _find_second_point(shared_dists, bases, cloud, prec):
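    """Find a second motif point: a candidate at distance
    ``shared_dists[0]`` from the origin whose distances to ``dims``
    neighbour-set vectors match a combination of the remaining shared
    distances, and which lies no closer to any other lattice point in
    ``cloud`` than to the origin. Return None if no candidate is found.
    """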
    dims = cloud.shape[-1]
    abs_q = shared_dists[0]

    for distance_tup in combinations(shared_dists[1:], dims):
        for basis in bases:
            res = None
            if dims == 2:
                res = _bilaterate(*basis, *distance_tup, abs_q, prec)
            elif dims == 3:
                if not _four_sphere_pairwise_intersecion(
                    *basis, *distance_tup, abs_q, prec
                ):
                    continue
                res = _trilaterate(*basis, *distance_tup, abs_q, prec)

            if res is not None:
                cloud_res_dists = np.linalg.norm(cloud - res, axis=-1)
                if np.all(cloud_res_dists - abs_q + prec > 0):
                    return res


def _find_further_point(shared_dists1, shared_dists2, bases, cloud, q, prec):
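    """Find a further motif point: a candidate at distance
    ``shared_dists1[0]`` from the origin whose distances to ``dims``
    neighbour-set vectors match distances shared with the first row,
    which lies in the Voronoi domain of the origin, and whose distance to
    the second motif point ``q`` appears among ``shared_dists2``. Return
    None if no candidate is found.
    """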
    # distance from origin (first motif point) to the further point
    dims = cloud.shape[-1]
    abs_q_ = shared_dists1[0]

    # try all ordered subsequences of distances shared between the first and
    # further row, with all combinations of the vectors in the neighbour set
    # forming a basis
    for distance_tup in combinations(shared_dists1[1:], dims):
        for basis in bases:
            res = None
            if dims == 2:
                res = _bilaterate(*basis, *distance_tup, abs_q_, prec)
            elif dims == 3:
                if not _four_sphere_pairwise_intersecion(
                    *basis, *distance_tup, abs_q_, prec
                ):
                    continue
                res = _trilaterate(*basis, *distance_tup, abs_q_, prec)

            if res is not None:
                # check the point is in the Voronoi domain of the origin
                cloud_res_dists = np.linalg.norm(cloud - res, axis=-1)
                if np.all(cloud_res_dists - abs_q_ + prec > 0):
                    # check |q - res| is among the distances shared with the
                    # second row
                    dist_diff = np.abs(shared_dists2 - np.linalg.norm(q - res))
                    if np.any(dist_diff < prec):
                        return res


def _neighbour_set(cell, prec):
    """(A superset of) the neighbour set of the origin for a lattice."""

    k_ = 5
    coeffs = np.array(list(product((-1, 0, 1), repeat=cell.shape[0])))
    coeffs = coeffs[coeffs.any(axis=-1)]    # remove (0,0,0)

    # halves of all +-1/0 integer combinations of the basis vectors
    vecs = []
    for c in coeffs:
        vecs.append(np.sum(cell * c[None, :].T, axis=0) / 2)
    vecs = np.array(vecs)

    origin = np.zeros((1, cell.shape[0]))
    cloud_generator = generate_concentric_cloud(origin, cell)
    cloud = np.concatenate((next(cloud_generator), next(cloud_generator)))
    tree = KDTree(cloud, compact_nodes=False, balanced_tree=False)
    # expand the cloud until the k_ nearest-neighbour distances stop changing
    dists_, inds = tree.query(vecs, k=k_, workers=-1)
    dists = np.zeros_like(dists_)

    while not np.allclose(dists, dists_, atol=0, rtol=1e-12):
        dists = dists_
        cloud = np.vstack((cloud,
                           next(cloud_generator),
                           next(cloud_generator)))
        tree = KDTree(cloud, compact_nodes=False, balanced_tree=False)
        dists_, inds = tree.query(vecs, k=k_, workers=-1)

    tmp_inds = np.unique(inds[:, 1:].flatten())
    tmp_inds = tmp_inds[tmp_inds != 0]
    neighbour_set = cloud[tmp_inds]

    # reduce the neighbour set:
    # halve the lattice points and find their nearest neighbours in the lattice
    neighbour_set_half = neighbour_set / 2
    # for each of these halved vectors, check whether 0 is a nearest
    # neighbour, i.e. whether the distance to 0 is less than or equal
    # (within tolerance) to the distance to every other lattice point
    nn_norms = np.linalg.norm(neighbour_set, axis=-1)
    halves_norms = nn_norms / 2
    halves_to_lattice_dists = cdist(neighbour_set_half, neighbour_set)

    # TODO: is the + prec needed here?
    # indices of Voronoi neighbours in the cloud
    voronoi_neighbours = np.all(
        halves_to_lattice_dists - halves_norms + prec >= 0, axis=-1
    )
    neighbour_set = neighbour_set[voronoi_neighbours]
    return neighbour_set


def _four_sphere_pairwise_intersecion(p1, p2, p3, r1, r2, r3, abs_val, prec):
    """Return False if any pair of the four spheres (centred at the
    origin, ``p1``, ``p2`` and ``p3`` with radii ``abs_val``, ``r1``,
    ``r2`` and ``r3``) are too far apart to intersect; True otherwise.
    This is a quick necessary (but not sufficient) check for a common
    intersection point.
    """

    if np.linalg.norm(p1) > abs_val + r1 - prec:
        return False
    if np.linalg.norm(p2) > abs_val + r2 - prec:
        return False
    if np.linalg.norm(p3) > abs_val + r3 - prec:
        return False
    if np.linalg.norm(p1 - p2) > r1 + r2 - prec:
        return False
    if np.linalg.norm(p1 - p3) > r1 + r3 - prec:
        return False
    if np.linalg.norm(p2 - p3) > r2 + r3 - prec:
        return False
    return True


@numba.njit()
def _bilaterate(p1, p2, r1, r2, abs_val, prec):
    """Return a point where the two circles centred at ``p1`` and ``p2``
    with radii ``r1`` and ``r2`` intersect, provided it lies at distance
    ``abs_val`` from the origin (within ``prec``); otherwise return None.
    """

    d = np.sqrt((p2[0] - p1[0]) ** 2 + (p2[1] - p1[1]) ** 2)

    if d > r1 + r2:
        return None
    if d < abs(r1 - r2):
        return None
    if d == 0 and r1 == r2:
        return None

    v = (p2 - p1) / d
    a = (r1 ** 2 - r2 ** 2 + d ** 2) / (2 * d)
    h = np.sqrt(r1 ** 2 - a ** 2)
    x2 = p1[0] + a * v[0]
    y2 = p1[1] + a * v[1]
    x3 = x2 + h * v[1]
    y3 = y2 - h * v[0]
    x4 = x2 - h * v[1]
    y4 = y2 + h * v[0]
    q1 = np.array((x3, y3))
    q2 = np.array((x4, y4))

    if np.abs(np.sqrt(x3 ** 2 + y3 ** 2) - abs_val) < prec:
        return q1
    if np.abs(np.sqrt(x4 ** 2 + y4 ** 2) - abs_val) < prec:
        return q2
    return None
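
# Illustrative check (hypothetical values): the point (0.3, 0.4) lies at
# distance 0.5 from the origin, sqrt(0.65) from (1, 0) and sqrt(0.45) from
# (0, 1), so _bilaterate(np.array((1.0, 0.0)), np.array((0.0, 1.0)),
# np.sqrt(0.65), np.sqrt(0.45), 0.5, 1e-8) returns approximately (0.3, 0.4).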


@numba.njit()
def _trilaterate(p1, p2, p3, r1, r2, r3, abs_val, prec):
    """Return a point where the three spheres centred at ``p1``, ``p2``
    and ``p3`` with radii ``r1``, ``r2`` and ``r3`` intersect, provided it
    lies at distance ``abs_val`` from the origin (within ``prec``);
    otherwise return None.
    """

    temp1 = p2 - p1
    d = np.linalg.norm(temp1)
    e_x = temp1 / d
    temp2 = p3 - p1
    i = np.dot(e_x, temp2)
    temp3 = temp2 - i * e_x
    e_y = temp3 / np.linalg.norm(temp3)

    j = np.dot(e_y, temp2)
    x = (r1 * r1 - r2 * r2 + d * d) / (2 * d)
    y = (r1 * r1 - r3 * r3 - 2 * i * x + i * i + j * j) / (2 * j)
    temp4 = r1 * r1 - x * x - y * y

    if temp4 < 0:
        return None

    e_z = np.cross(e_x, e_y)
    z = np.sqrt(temp4)
    p_12_a = p1 + x * e_x + y * e_y + z * e_z
    p_12_b = p1 + x * e_x + y * e_y - z * e_z

    if np.abs(np.linalg.norm(p_12_a) - abs_val) < prec:
        return p_12_a
    if np.abs(np.linalg.norm(p_12_b) - abs_val) < prec:
        return p_12_b
    return None


def _unique_within_tol(arr, prec):
    """Return ``arr`` with any value removed that lies within ``prec`` of
    an earlier value in the array."""
    return arr[~np.any(np.triu(np.abs(arr[:, None] - arr) < prec, 1), axis=0)]
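
# Illustrative behaviour (hypothetical values): _unique_within_tol applied to
# np.array([1.0, 1.0 + 1e-12, 2.0]) with prec=1e-10 drops the near-duplicate,
# returning array([1., 2.]).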


def _remove_vals(vec, vals_to_remove, prec):
    """Remove from ``vec`` any value within ``prec`` of a value in
    ``vals_to_remove``."""
    return vec[~np.any(np.abs(vec[:, None] - vals_to_remove) < prec, axis=-1)]


def _shared_vals(v1, v2, prec):
    """Return the values of ``v1`` that lie within ``prec`` of some value
    in ``v2``."""
    return v1[np.argwhere(np.abs(v1[:, None] - v2) < prec)[:, 0]]