Passed
Push — master ( 406217...8c12c2 )
by Daniel
03:55
created

amd.reconstruct._shared_vals()   A

Complexity

Conditions 1

Size

Total Lines 3
Code Lines 2

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 2
dl 0
loc 3
rs 10
c 0
b 0
f 0
cc 1
nop 3
1
"""Functions for reconstructing a periodic set up to isometry from its
2
PDD. This is possible 'in a general position', see our papers for more.
3
"""
4
5
from itertools import combinations, permutations, product
6
7
import numpy as np
8
import numba
9
from scipy.spatial.distance import cdist
10
from scipy.spatial import KDTree
11
12
from ._nearest_neighbours import generate_concentric_cloud
13
from .utils import diameter
14
15
__all__ = ['reconstruct']
16
17
18
def reconstruct(pdd: np.ndarray, cell: np.ndarray) -> np.ndarray:
    """Rebuild a motif from an uncollapsed, unweighted PDD and a unit cell.

    ``pdd`` must have enough columns that every value in its last column
    exceeds twice the diameter of the unit cell, and it must be
    uncollapsed with no weights column. ``amd.PDD`` does not produce such
    an array; use ``amd.PDD_reconstructable`` instead. The search is slow
    and its run time can vary a lot with the input.

    Parameters
    ----------
    pdd : :class:`numpy.ndarray`
        The PDD of the periodic set to reconstruct, one row per motif
        point. Needs `k` large enough that all values in the last column
        are greater than :code:`2 * diameter(cell)`, uncollapsed and
        without weights. Use amd.PDD_reconstructable to obtain it.
    cell : :class:`numpy.ndarray`
        Unit cell of the periodic set to reconstruct.

    Returns
    -------
    :class:`numpy.ndarray`
        The reconstructed motif of the periodic set. The first point is
        always the origin.

    Raises
    ------
    ValueError
        If ``cell`` is not 2- or 3-dimensional.
    RuntimeError
        If the second or any further motif point cannot be located.
    """

    # TODO: get a more reduced neighbour set
    # TODO: find all shared distances in a big operation at the start
    # TODO: move PREC variable to its proper place
    PREC = 1e-10

    dims = cell.shape[0]
    if dims not in (2, 3):
        raise ValueError(
            'Reconstructing from PDD only implemented for 2 and 3 dimensions'
        )
    diam = diameter(cell)
    origin = np.zeros((dims, ))
    n_points = pdd.shape[0]
    if n_points == 1:
        return np.array([origin])

    # Collect lattice points around the origin, layer by layer, for as
    # long as a layer still contains a point within ``diam``; their
    # distances are lattice distances that carry no motif information.
    layers = generate_concentric_cloud(np.array([origin]), cell)
    next(layers)  # is just the origin
    kept_layers = []
    layer = next(layers)
    while np.any(np.linalg.norm(layer, axis=-1) <= diam):
        kept_layers.append(layer)
        layer = next(layers)
    cloud = np.concatenate(kept_layers)

    # is (a superset of) lattice points close enough to Voronoi domain
    nn_set = _neighbour_set(cell, PREC)
    cloud_norms = np.linalg.norm(cloud, axis=-1)
    lattice_dists = _unique_within_tol(cloud_norms[cloud_norms <= diam], PREC)

    # Strip lattice distances from the first two rows, then keep the
    # distinct distances the two rows have in common.
    row1_reduced = _remove_vals(pdd[0], lattice_dists, PREC)
    row2_reduced = _remove_vals(pdd[1], lattice_dists, PREC)
    shared_dists = _unique_within_tol(
        _shared_vals(row1_reduced, row2_reduced, PREC), PREC
    )

    # Every ordered selection of ``dims`` linearly independent vectors
    # from the neighbour set.
    bases = []
    for vecs in combinations(nn_set, dims):
        candidate = np.asarray(vecs)
        if np.abs(np.linalg.det(candidate)) > PREC:
            bases.extend(permutations(candidate, dims))

    q = _find_second_point(shared_dists, bases, cloud, PREC)
    if q is None:
        raise RuntimeError('Second point of motif could not be found.')
    motif = [origin, q]

    if n_points == 2:
        return np.array(motif)

    # Each remaining row is located from the distances it shares with
    # the first row and with the second row.
    for row in pdd[2:, :]:
        row_reduced = _remove_vals(row, lattice_dists, PREC)
        shared_with_row1 = _unique_within_tol(
            _shared_vals(row1_reduced, row_reduced, PREC), PREC
        )
        shared_with_row2 = _unique_within_tol(
            _shared_vals(row2_reduced, row_reduced, PREC), PREC
        )
        q_ = _find_further_point(
            shared_with_row1, shared_with_row2, bases, cloud, q, PREC
        )
        if q_ is None:
            raise RuntimeError('Further point of motif could not be found.')
        motif.append(q_)

    # Wrap all points into the unit cell via fractional coordinates.
    motif = np.array(motif)
    return np.mod(motif @ np.linalg.inv(cell), 1) @ cell
114
115
116
def _find_second_point(shared_dists, bases, cloud, prec):
117
    dims = cloud.shape[-1]
118
    abs_q = shared_dists[0]
119
    sphere_intersect_func = _trilaterate if dims == 3 else _bilaterate
120
121
    for distance_tup in combinations(shared_dists[1:], dims):
122
        for basis in bases:
123
            res = sphere_intersect_func(*basis, *distance_tup, abs_q, prec)
124
            if res is None:
125
                continue
126
            cloud_res_dists = np.linalg.norm(cloud - res, axis=-1)
127
            if np.all(cloud_res_dists - abs_q + prec > 0):
128
                return res
129
130
131
def _find_further_point(shared_dists1, shared_dists2, bases, cloud, q, prec):
132
    # distance from origin (first motif point) to further point
133
    dims = cloud.shape[-1]
134
    abs_q_ = shared_dists1[0]
135
136
    # try all ordered subsequences of distances shared between first and
137
    # further row, with all combinations of the vectors in the neighbour set
138
    # forming a basis, see if spheres centered at the vectors with the shared
139
    # distances as radii intersect at 4 (3 dims) points.
140
    sphere_intersect_func = _trilaterate if dims == 3 else _bilaterate
141
    for distance_tup in combinations(shared_dists1[1:], dims):
142
        for basis in bases:
143
            res = sphere_intersect_func(*basis, *distance_tup, abs_q_, prec)
144
            if res is None:
145
                continue
146
            # check point is in the voronoi domain
147
            cloud_res_dists = np.linalg.norm(cloud - res, axis=-1)
148
            if not np.all(cloud_res_dists - abs_q_ + prec > 0):
149
                continue
150
            # check |p - point| is among the row's shared distances
151
            dist_diff = np.abs(shared_dists2 - np.linalg.norm(q - res))
152
            if np.any(dist_diff < prec):
153
                return res
154
155
156
def _neighbour_set(cell, prec):
    """(A superset of) the neighbour set of origin for a lattice.

    Parameters
    ----------
    cell : :class:`numpy.ndarray`
        Square array whose rows are the lattice basis vectors.
    prec : float
        Absolute tolerance used in the final filtering step.

    Returns
    -------
    :class:`numpy.ndarray`
        Lattice points (rows) kept after the filtering below.
    """

    k_ = 5
    dims = cell.shape[0]
    coeffs = np.array(list(product((-1, 0, 1), repeat=dims)))
    coeffs = coeffs[coeffs.any(axis=-1)]    # remove (0,0,0)

    # half of all combinations of basis vectors
    vecs = np.array(
        [np.sum(cell * c[None, :].T, axis=0) / 2 for c in coeffs]
    )

    origin = np.zeros((1, dims))
    # NOTE(review): generate_concentric_cloud is assumed to yield successive
    # layers of lattice points, the first being the origin itself -- confirm.
    cloud_generator = generate_concentric_cloud(origin, cell)
    cloud = np.concatenate((next(cloud_generator), next(cloud_generator)))

    # Grow the cloud two layers at a time and query the k_ nearest cloud
    # points of every half-vector, until two consecutive queries return
    # the same distances. The previous result is tracked explicitly
    # (instead of comparing against an uninitialised ``np.empty_like``
    # buffer), so the stopping test never depends on arbitrary memory.
    prev_dists = None
    while True:
        cloud = np.vstack(
            (cloud, next(cloud_generator), next(cloud_generator))
        )
        tree = KDTree(cloud, compact_nodes=False, balanced_tree=False)
        dists, inds = tree.query(vecs, k=k_)
        if prev_dists is not None and np.allclose(
            prev_dists, dists, atol=0, rtol=1e-12
        ):
            break
        prev_dists = dists

    # Drop each query's nearest hit and cloud index 0 (presumably the
    # origin, as the first point of the first layer -- confirm).
    tmp_inds = np.unique(inds[:, 1:].flatten())
    tmp_inds = tmp_inds[tmp_inds != 0]
    neighbour_set = cloud[tmp_inds]

    # reduce neighbour set
    # half the lattice points and find their nearest neighbours in the lattice
    neighbour_set_half = neighbour_set / 2
    # for each of these vectors, check if 0 is a nearest neighbour.
    # so, check if the dist to 0 is leq (within tol) than dist to all other
    # lattice points.
    nn_norms = np.linalg.norm(neighbour_set, axis=-1)
    halves_norms = nn_norms / 2
    halves_to_lattice_dists = cdist(neighbour_set_half, neighbour_set)

    # Do I need to + PREC?
    # inds of voronoi neighbours in cloud
    voronoi_neighbours = np.all(
        halves_to_lattice_dists - halves_norms + prec >= 0, axis=-1
    )
    return neighbour_set[voronoi_neighbours]
205
206
207
@numba.njit(cache=True)
def _bilaterate(p1, p2, r1, r2, abs_val, prec):
    """Intersect two circles and keep the point at a given distance from 0.

    The circles are centred at ``p1`` and ``p2`` with radii ``r1`` and
    ``r2``. Of their two intersection points, the first whose distance
    from the origin equals ``abs_val`` within ``prec`` is returned;
    otherwise (or if the circles do not intersect in isolated points)
    None is returned.
    """

    d = np.sqrt((p2[0] - p1[0]) ** 2 + (p2[1] - p1[1]) ** 2)
    direction = (p2 - p1) / d

    # too far apart, one circle inside the other, or coincident circles
    if d > r1 + r2 or d < abs(r1 - r2) or (d == 0 and r1 == r2):
        return None

    # ``a`` is the signed distance from p1 along ``direction`` to the foot
    # of the chord through both intersections; ``h`` is the half-chord.
    a = (r1 ** 2 - r2 ** 2 + d ** 2) / (2 * d)
    h = np.sqrt(r1 ** 2 - a ** 2)
    foot_x = p1[0] + a * direction[0]
    foot_y = p1[1] + a * direction[1]

    first_x = foot_x + h * direction[1]
    first_y = foot_y - h * direction[0]
    if np.abs(np.sqrt(first_x ** 2 + first_y ** 2) - abs_val) < prec:
        return np.array((first_x, first_y))

    second_x = foot_x - h * direction[1]
    second_y = foot_y + h * direction[0]
    if np.abs(np.sqrt(second_x ** 2 + second_y ** 2) - abs_val) < prec:
        return np.array((second_x, second_y))

    return None
237
238
239
@numba.njit(cache=True)
def _trilaterate(p1, p2, p3, r1, r2, r3, abs_val, prec):
    """Intersect three spheres and keep the point at a given distance from 0.

    The spheres are centred at ``p1``, ``p2``, ``p3`` with radii ``r1``,
    ``r2``, ``r3``. Of their (up to two) common points, the first whose
    norm equals ``abs_val`` within ``prec`` is returned; None is returned
    if there is no such point.
    """

    # Quick rejections: a point at distance abs_val from the origin and
    # r_i from p_i needs |p_i| <= abs_val + r_i; cases within ``prec`` of
    # that bound are rejected too.
    if (
        np.linalg.norm(p1) > abs_val + r1 - prec
        or np.linalg.norm(p2) > abs_val + r2 - prec
        or np.linalg.norm(p3) > abs_val + r3 - prec
    ):
        return None
    # Likewise, two spheres can only meet if their centres are no further
    # apart than the sum of their radii.
    if (
        np.linalg.norm(p1 - p2) > r1 + r2 - prec
        or np.linalg.norm(p1 - p3) > r1 + r3 - prec
        or np.linalg.norm(p2 - p3) > r2 + r3 - prec
    ):
        return None

    # Orthonormal frame with origin p1, e_x towards p2, and e_y in the
    # plane of p1, p2, p3.
    p1_to_p2 = p2 - p1
    d = np.linalg.norm(p1_to_p2)
    e_x = p1_to_p2 / d
    p1_to_p3 = p3 - p1
    i = np.dot(e_x, p1_to_p3)
    off_axis = p1_to_p3 - i * e_x
    e_y = off_axis / np.linalg.norm(off_axis)
    j = np.dot(e_y, p1_to_p3)

    # Coordinates of the intersection in that frame.
    x = (r1 * r1 - r2 * r2 + d * d) / (2 * d)
    y = (r1 * r1 - r3 * r3 - 2 * i * x + i * i + j * j) / (2 * j)
    z_squared = r1 * r1 - x * x - y * y
    if z_squared < 0:
        return None

    e_z = np.cross(e_x, e_y)
    z = np.sqrt(z_squared)
    in_plane = p1 + x * e_x + y * e_y

    above = in_plane + z * e_z
    if np.abs(np.linalg.norm(above) - abs_val) < prec:
        return above
    below = in_plane - z * e_z
    if np.abs(np.linalg.norm(below) - abs_val) < prec:
        return below
    return None
282
283
284
def _unique_within_tol(arr, prec):
285
    """Return only unique values in a vector within ``prec``."""
286
    return arr[~np.any(np.triu(np.abs(arr[:, None] - arr) < prec, 1), axis=0)]
287
288
289
def _remove_vals(vec, vals_to_remove, prec):
290
    """Remove specified values in vec, within ``prec``."""
291
    return vec[~np.any(np.abs(vec[:, None] - vals_to_remove) < prec, axis=-1)]
292
293
294
def _shared_vals(v1, v2, prec):
295
    """Return values shared between v1, v2 within ``prec``."""
296
    return v1[np.argwhere(np.abs(v1[:, None] - v2) < prec)[:, 0]]
297