Test Failed
Push — master ( 37d7fb...c02a6e )
by Daniel
07:38
created

amd.reconstruct._trilaterate()   A

Complexity

Conditions 4

Size

Total Lines 30
Code Lines 24

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 24
dl 0
loc 30
rs 9.304
c 0
b 0
f 0
cc 4
nop 8

How to fix   Many Parameters   

Many Parameters

Methods with many parameters are not only hard to understand, but their parameters also often become inconsistent when you need more or different data.

There are several approaches to avoid long parameter lists:

1
"""Functions for reconstructing a periodic set up to isometry from its
2
PDD. This is possible for sets 'in general position'; see our papers for more.
3
"""
4
5
from itertools import combinations, permutations, product
6
7
import numpy as np
8
import numpy.typing as npt
9
import numba
10
from scipy.spatial.distance import cdist
11
from scipy.spatial import KDTree
12
13
from ._nearest_neighbours import generate_concentric_cloud
14
from .utils import diameter
15
16
17
def reconstruct(
        pdd: npt.NDArray,
        cell: npt.NDArray
) -> npt.NDArray:
    """Reconstruct a motif from a PDD and unit cell. This function will
    only work if ``pdd`` has enough columns, such that the last column
    has all values larger than 2 times the diameter of the unit cell. It
    also expects an uncollapsed PDD with no weights column. Do not use
    ``amd.PDD`` to compute the PDD for this function, instead use
    ``amd.PDD_reconstructable`` which returns a version of the PDD which
    is passable to this function. This function is quite slow and run
    time may vary a lot arbitrarily depending on input.

    Parameters
    ----------
    pdd : :class:`numpy.ndarray`
        The PDD of the periodic set to reconstruct. Needs `k` at least
        large enough so all values in the last column of pdd are greater
        than :code:`2 * diameter(cell)`, and needs to be uncollapsed
        without weights. Use amd.PDD_reconstructable to get a PDD which
        is acceptable for this argument.
    cell : :class:`numpy.ndarray`
        Unit cell of the periodic set to reconstruct.

    Returns
    -------
    :class:`numpy.ndarray`
        The reconstructed motif of the periodic set, one point per row of
        ``pdd``. The first point is the origin. Whenever there is more
        than one point, all points are wrapped into the unit cell.

    Raises
    ------
    ValueError
        If ``cell`` is not 2 or 3 dimensional.
    RuntimeError
        If the second or a further motif point cannot be found.
    """

    # TODO: get a more reduced neighbour set
    # TODO: find all shared distances in a big operation at the start
    # TODO: move PREC variable to its proper place
    PREC = 1e-10

    dims = cell.shape[0]
    if dims not in (2, 3):
        raise ValueError(
            'Reconstructing from PDD only implemented for 2 and 3 dimensions'
        )

    diam = diameter(cell)
    # Set the first point as the origin wlog; return if 1 motif point.
    motif = [np.zeros((dims, ))]

    if pdd.shape[0] == 1:
        return np.array(motif)

    # Lattice points around the origin out to distance ``diam``, so lattice
    # distances can be ignored. The first generated layer is skipped,
    # presumably the origin itself -- TODO confirm against
    # generate_concentric_cloud.
    cloud_generator = generate_concentric_cloud(np.array(motif), cell)
    cloud = []
    next(cloud_generator)
    layer = next(cloud_generator)
    while np.any(np.linalg.norm(layer, axis=-1) <= diam):
        cloud.append(layer)
        layer = next(cloud_generator)
    cloud = np.concatenate(cloud)

    # (A superset of) lattice points close enough to the Voronoi domain.
    nn_set = _neighbour_set(cell, PREC)
    lattice_dists = np.linalg.norm(cloud, axis=-1)
    lattice_dists = lattice_dists[lattice_dists <= diam]
    lattice_dists = _unique_within_tol(lattice_dists, PREC)

    # Remove lattice distances from the first and second rows, then get the
    # distances shared between them.
    row1_reduced = _remove_vals(pdd[0], lattice_dists, PREC)
    row2_reduced = _remove_vals(pdd[1], lattice_dists, PREC)
    shared_dists = _shared_vals(row1_reduced, row2_reduced, PREC)
    shared_dists = _unique_within_tol(shared_dists, PREC)

    # Every ordering of every ``dims`` neighbour vectors forming a basis
    # (|det| > PREC); these serve as the centres when locating points.
    bases = []
    for vecs in combinations(nn_set, dims):
        vecs = np.asarray(vecs)
        if np.abs(np.linalg.det(vecs)) > PREC:
            bases.extend(permutations(vecs, dims))

    q = _find_second_point(shared_dists, bases, cloud, PREC)
    if q is None:
        raise RuntimeError('Second point of motif could not be found.')
    motif.append(q)

    # Each further row: distances shared with row 1 locate the point, and
    # distances shared with row 2 are used to check its distance to q.
    # With exactly two rows this loop is empty and we fall through to the
    # common wrapping step below.
    for row in pdd[2:, :]:
        row_reduced = _remove_vals(row, lattice_dists, PREC)
        shared_dists1 = _shared_vals(row1_reduced, row_reduced, PREC)
        shared_dists2 = _shared_vals(row2_reduced, row_reduced, PREC)
        shared_dists1 = _unique_within_tol(shared_dists1, PREC)
        shared_dists2 = _unique_within_tol(shared_dists2, PREC)
        q_ = _find_further_point(
            shared_dists1, shared_dists2, bases, cloud, q, PREC
        )

        if q_ is None:
            raise RuntimeError('Further point of motif could not be found.')

        motif.append(q_)

    # Wrap all points into the unit cell (fractional coordinates mod 1) so
    # that the output representation does not depend on the motif size.
    motif = np.array(motif)
    motif = np.mod(motif @ np.linalg.inv(cell), 1) @ cell
    return motif
125
126
127
def _find_second_point(shared_dists, bases, cloud, prec):
128
    dims = cloud.shape[-1]
129
    abs_q = shared_dists[0]
130
131
    for distance_tup in combinations(shared_dists[1:], dims):
132
        for basis in bases:
133
            res = None
134
            if dims == 2:
135
                res = _bilaterate(*basis, *distance_tup, abs_q, prec)
136
            elif dims == 3:
137
                if not _four_sphere_pairwise_intersecion(
138
                    *basis, *distance_tup, abs_q, prec
139
                ):
140
                    continue
141
                res = _trilaterate(*basis, *distance_tup, abs_q, prec)
142
143
            if res is not None:
144
                cloud_res_dists = np.linalg.norm(cloud - res, axis=-1)
145
                if np.all(cloud_res_dists - abs_q + prec > 0):
146
                    return res
147
148
149
def _find_further_point(shared_dists1, shared_dists2, bases, cloud, q, prec):
150
    # distance from origin (first motif point) to further point
151
    dims = cloud.shape[-1]
152
    abs_q_ = shared_dists1[0]
153
154
    # try all ordered subsequences of distances shared between first and
155
    # further row, with all combinations of the vectors in the neighbour set
156
    # forming a basis
157
    for distance_tup in combinations(shared_dists1[1:], dims):
158
        for basis in bases:
159
            res = None
160
            if dims == 2:
161
                res = _bilaterate(*basis, *distance_tup, abs_q_, prec)
162
            elif dims == 3:
163
                if not _four_sphere_pairwise_intersecion(
164
                    *basis, *distance_tup, abs_q_, prec
165
                ):
166
                    continue
167
                res = _trilaterate(*basis, *distance_tup, abs_q_, prec)
168
169
            if res is not None:
170
                # check point is in the voronoi domain
171
                cloud_res_dists = np.linalg.norm(cloud - res, axis=-1)
172
                if np.all(cloud_res_dists - abs_q_ + prec > 0):
173
                    # check |p - point| is among the row's shared distances
174
                    dist_diff = np.abs(shared_dists2 - np.linalg.norm(q - res))
175
                    if np.any(dist_diff < prec):
176
                        return res
177
178
179
def _neighbour_set(cell, prec):
    """(A superset of) the neighbour set of origin for a lattice.

    Parameters
    ----------
    cell : :class:`numpy.ndarray`
        Square matrix whose rows are combined as the lattice's basis
        vectors.
    prec : float
        Tolerance used when discarding lattice vectors in the final
        reduction step.

    Returns
    -------
    :class:`numpy.ndarray`
        Lattice vectors, one per row, forming (a superset of) the neighbours
        of the origin.
    """

    k_ = 5
    dims = cell.shape[0]

    # All coefficient vectors in {-1, 0, 1}^dims except the zero vector.
    coeffs = np.array(list(product((-1, 0, 1), repeat=dims)))
    coeffs = coeffs[coeffs.any(axis=-1)]

    # half of all combinations of basis vectors
    vecs = np.array(
        [np.sum(cell * c[None, :].T, axis=0) / 2 for c in coeffs]
    )

    origin = np.zeros((1, dims))
    cloud_generator = generate_concentric_cloud(origin, cell)
    cloud = np.concatenate((next(cloud_generator), next(cloud_generator)))
    tree = KDTree(cloud, compact_nodes=False, balanced_tree=False)
    dists, inds = tree.query(vecs, k=k_, workers=-1)

    # Grow the cloud two layers at a time until the k nearest-neighbour
    # distances of every half-vector stop changing. Each iteration compares
    # two real queries; ``inds`` always refers to the final cloud.
    while True:
        cloud = np.vstack((cloud,
                           next(cloud_generator),
                           next(cloud_generator)))
        tree = KDTree(cloud, compact_nodes=False, balanced_tree=False)
        new_dists, inds = tree.query(vecs, k=k_, workers=-1)
        if np.allclose(dists, new_dists, atol=0, rtol=1e-12):
            break
        dists = new_dists

    # Drop the first neighbour column and index 0 (assumed to be the origin,
    # the first point of the generated cloud -- TODO confirm).
    tmp_inds = np.unique(inds[:, 1:].flatten())
    tmp_inds = tmp_inds[tmp_inds != 0]
    neighbour_set = cloud[tmp_inds]

    # reduce neighbour set
    # half the lattice points and find their nearest neighbours in the lattice
    neighbour_set_half = neighbour_set / 2
    # For each half-vector, check whether 0 is A nearest neighbour: its
    # distance to 0 must be <= (within tol) its distance to all other
    # lattice points.
    nn_norms = np.linalg.norm(neighbour_set, axis=-1)
    halves_norms = nn_norms / 2
    halves_to_lattice_dists = cdist(neighbour_set_half, neighbour_set)

    # NOTE(review): ``halves_norms`` broadcasts along the last axis, so entry
    # (i, j) is compared with |v_j|/2 rather than |v_i|/2. This keeps a
    # superset of the row-wise test; confirm whether that is intended.
    voronoi_neighbours = np.all(
        halves_to_lattice_dists - halves_norms + prec >= 0, axis=-1
    )
    neighbour_set = neighbour_set[voronoi_neighbours]
    return neighbour_set
228
229
230
def _four_sphere_pairwise_intersecion(p1, p2, p3, r1, r2, r3, abs_val, prec):
231
    """Return True if four spheres intersect at a point."""
232
233
    if np.linalg.norm(p1) > abs_val + r1 - prec:
234
        return False
235
    if np.linalg.norm(p2) > abs_val + r2 - prec:
236
        return False
237
    if np.linalg.norm(p3) > abs_val + r3 - prec:
238
        return False
239
    if np.linalg.norm(p1 - p2) > r1 + r2 - prec:
240
        return False
241
    if np.linalg.norm(p1 - p3) > r1 + r3 - prec:
242
        return False
243
    if np.linalg.norm(p2 - p3) > r2 + r3 - prec:
244
        return False
245
    return True
246
247
248
@numba.njit()
def _bilaterate(p1, p2, r1, r2, abs_val, prec):
    """Intersect two circles and pick the solution with norm ``abs_val``.

    The circles are centred at ``p1`` and ``p2`` (their distance uses only
    the first two coordinates) with radii ``r1`` and ``r2``. Of their (up to
    two) intersection points, the first whose norm is within ``prec`` of
    ``abs_val`` is returned as a length-2 array. That point also lies on the
    circle of radius ``abs_val`` about the origin.

    Returns None if the circles are separate, nested or concentric, or if
    neither intersection has the required norm.
    """
    d = np.sqrt((p2[0] - p1[0]) ** 2 + (p2[1] - p1[1]) ** 2)

    # Reject concentric, separate and nested circles before dividing by d.
    if d == 0 or d > r1 + r2 or d < abs(r1 - r2):
        return None

    v = (p2 - p1) / d
    a = (r1 ** 2 - r2 ** 2 + d ** 2) / (2 * d)
    # After the checks above r1**2 >= a**2 mathematically. Clamp so that
    # rounding in (near-)tangent cases yields h = 0 rather than nan, which
    # would otherwise silently lose the tangent point.
    h = np.sqrt(max(r1 ** 2 - a ** 2, 0.0))
    x2 = p1[0] + a * v[0]
    y2 = p1[1] + a * v[1]
    x3 = x2 + h * v[1]
    y3 = y2 - h * v[0]
    x4 = x2 - h * v[1]
    y4 = y2 + h * v[0]

    if np.abs(np.sqrt(x3 ** 2 + y3 ** 2) - abs_val) < prec:
        return np.array((x3, y3))
    if np.abs(np.sqrt(x4 ** 2 + y4 ** 2) - abs_val) < prec:
        return np.array((x4, y4))
    return None
278
279
280
@numba.njit()
def _trilaterate(p1, p2, p3, r1, r2, r3, abs_val, prec):
    """Intersect three spheres and pick the solution with norm ``abs_val``.

    The spheres are centred at ``p1``, ``p2``, ``p3`` with radii ``r1``,
    ``r2``, ``r3``. Their (up to two) common points are computed in an
    orthonormal frame built from the centres. The first one whose norm is
    within ``prec`` of ``abs_val`` is returned; it also lies on the sphere
    of radius ``abs_val`` about the origin. Otherwise None is returned.

    Assumes the three centres are not collinear; otherwise the frame below
    divides by a zero length.
    """
    # Orthonormal frame: u along p1->p2, w the part of p1->p3 normal to u.
    to_p2 = p2 - p1
    span = np.linalg.norm(to_p2)
    u = to_p2 / span
    to_p3 = p3 - p1
    proj = np.dot(u, to_p3)
    perp = to_p3 - proj * u
    w = perp / np.linalg.norm(perp)
    height = np.dot(w, to_p3)

    # Coordinates of the intersection(s) in the (u, w, u x w) frame.
    cx = (r1 * r1 - r2 * r2 + span * span) / (2 * span)
    cy = (
        r1 * r1 - r3 * r3 - 2 * proj * cx + proj * proj + height * height
    ) / (2 * height)
    cz_sq = r1 * r1 - cx * cx - cy * cy

    # Negative squared height: the three spheres share no point.
    if cz_sq < 0:
        return None

    normal = np.cross(u, w)
    cz = np.sqrt(cz_sq)
    in_plane = p1 + cx * u + cy * w

    above = in_plane + cz * normal
    if np.abs(np.linalg.norm(above) - abs_val) < prec:
        return above
    below = in_plane - cz * normal
    if np.abs(np.linalg.norm(below) - abs_val) < prec:
        return below
    return None
310
311
312
def _unique_within_tol(arr, prec):
313
    """Return only unique values in a vector within ``prec``."""
314
    return arr[~np.any(np.triu(np.abs(arr[:, None] - arr) < prec, 1), axis=0)]
315
316
317
def _remove_vals(vec, vals_to_remove, prec):
318
    """Remove specified values in vec, within ``prec``."""
319
    return vec[~np.any(np.abs(vec[:, None] - vals_to_remove) < prec, axis=-1)]
320
321
322
def _shared_vals(v1, v2, prec):
323
    """Return values shared between v1, v2 within ``prec``."""
324
    return v1[np.argwhere(np.abs(v1[:, None] - v2) < prec)[:, 0]]
325