amd.reconstruct._trilaterate()   C
last analyzed

Complexity

Conditions 10

Size

Total Lines 43
Code Lines 36

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 36
dl 0
loc 43
rs 5.9999
c 0
b 0
f 0
cc 10
nop 8

How to fix   Complexity    Many Parameters   

Complexity

Complex classes like amd.reconstruct._trilaterate() often do a lot of different things. To break such a class down, we need to identify a cohesive component within that class. A common approach to find such a component is to look for fields/methods that share the same prefixes, or suffixes.

Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a sub-class, Extract Subclass is also a candidate, and is often faster.

Many Parameters

Methods with many parameters are not only hard to understand, but their parameters also often become inconsistent when you need more or different data.

There are several approaches to avoid long parameter lists:

1
"""Functions for reconstructing a periodic set up to isometry from its
2
PDD. This is possible for periodic sets in general position; see our papers for more.
3
"""
4
5
from itertools import combinations, permutations, product
6
7
import numpy as np
8
import numba
9
from scipy.spatial.distance import cdist
10
from scipy.spatial import KDTree
11
12
from ._nearest_neighbors import generate_concentric_cloud
13
from .utils import diameter
14
15
from ._types import FloatArray
16
17
# Only ``reconstruct`` is public; every other function in this module is an
# underscore-prefixed internal helper.
__all__ = ["reconstruct"]
18
19
20
def reconstruct(pdd: FloatArray, cell: FloatArray) -> FloatArray:
    """Reconstruct a motif from a PDD and unit cell. This function will
    only work if ``pdd`` has enough columns, such that the last column
    has all values larger than 2 times the diameter of the unit cell. It
    also expects an uncollapsed PDD with no weights column. Do not use
    ``amd.PDD`` to compute the PDD for this function, instead use
    ``amd.PDD_reconstructable`` which returns a version of the PDD which
    is passable to this function. This function is quite slow and run
    time may vary a lot arbitrarily depending on input.

    Parameters
    ----------
    pdd : :class:`numpy.ndarray`
        The PDD of the periodic set to reconstruct. Needs `k` at least
        large enough so all values in the last column of pdd are greater
        than :code:`2 * diameter(cell)`, and needs to be uncollapsed
        without weights. Use amd.PDD_reconstructable to get a PDD which
        is acceptable for this argument.
    cell : :class:`numpy.ndarray`
        Unit cell of the periodic set to reconstruct, one basis vector per
        row. Only 2 and 3 dimensional cells are supported.

    Returns
    -------
    :class:`numpy.ndarray`
        The reconstructed motif of the periodic set, one point per row of
        ``pdd``. The first point is the origin, and every point is reduced
        modulo the lattice into the unit cell.

    Raises
    ------
    ValueError
        If ``cell`` is not 2 or 3 dimensional.
    RuntimeError
        If no motif point consistent with ``pdd`` can be found.
    """

    # TODO: get a more reduced neighbor set
    # TODO: find all shared distances in a big operation at the start
    # TODO: move PREC variable to its proper place
    PREC = 1e-10

    dims = cell.shape[0]
    if dims not in (2, 3):
        raise ValueError(
            "Reconstructing from PDD only implemented for 2 and 3 dimensions"
        )
    diam = diameter(cell)
    motif = [np.zeros((dims,))]
    if pdd.shape[0] == 1:
        return np.array(motif)

    # Collect lattice points layer by layer until a whole layer lies beyond
    # ``diam``; their distances are lattice distances which appear in every
    # PDD row and carry no information about the motif.
    cloud_generator = iter(generate_concentric_cloud(np.array(motif), cell))
    next(cloud_generator)  # the origin
    layers = []
    layer = next(cloud_generator)
    while np.any(np.linalg.norm(layer, axis=-1) <= diam):
        layers.append(layer)
        layer = next(cloud_generator)
    cloud = np.concatenate(layers)

    # is (a superset of) lattice points close enough to Voronoi domain
    nn_set = _neighbor_set(cell, PREC)
    lattice_dists = np.linalg.norm(cloud, axis=-1)
    lattice_dists = lattice_dists[lattice_dists <= diam]
    lattice_dists = _unique_within_tol(lattice_dists, PREC)

    # remove lattice distances from first and second rows
    row1_reduced = _remove_vals(pdd[0], lattice_dists, PREC)
    row2_reduced = _remove_vals(pdd[1], lattice_dists, PREC)
    # get shared dists between first and second rows
    shared_dists = _shared_vals(row1_reduced, row2_reduced, PREC)
    shared_dists = _unique_within_tol(shared_dists, PREC)

    # Every ordering of every ``dims``-subset of the neighbor set that is
    # linearly independent (non-singular determinant) is a candidate basis.
    bases = []
    for vecs in combinations(nn_set, dims):
        vecs = np.asarray(vecs)
        if np.abs(np.linalg.det(vecs)) > PREC:
            bases.extend(permutations(vecs, dims))

    q = _find_second_point(shared_dists, bases, cloud, PREC)
    if q is None:
        raise RuntimeError("Second point of motif could not be found.")
    motif.append(q)

    # Each further point is located from the distances its row shares with
    # the first row, then checked against those it shares with the second.
    # (For a 2-row PDD this loop is empty; the motif still falls through to
    # the reduction below so all inputs are returned in the same form.)
    for row in pdd[2:, :]:
        row_reduced = _remove_vals(row, lattice_dists, PREC)
        shared_dists1 = _shared_vals(row1_reduced, row_reduced, PREC)
        shared_dists2 = _shared_vals(row2_reduced, row_reduced, PREC)
        shared_dists1 = _unique_within_tol(shared_dists1, PREC)
        shared_dists2 = _unique_within_tol(shared_dists2, PREC)
        q_ = _find_further_point(shared_dists1, shared_dists2, bases, cloud, q, PREC)
        if q_ is None:
            raise RuntimeError("Further point of motif could not be found.")
        motif.append(q_)

    # Reduce every point into the unit cell: to fractional coordinates,
    # modulo 1, and back to Cartesian coordinates.
    motif = np.array(motif)
    motif = np.mod(motif @ np.linalg.inv(cell), 1) @ cell
    return motif
114
115
116
def _voronoi_candidates(shared_dists, bases, cloud, prec):
    """Yield candidate motif points in the origin's Voronoi domain.

    The distance of the sought point from the origin is taken to be
    ``shared_dists[0]``. For every combination of ``dims`` further values
    from ``shared_dists[1:]`` and every basis in ``bases``, the spheres
    (circles in 2D) centred at the basis vectors with those radii are
    intersected with the sphere of radius ``shared_dists[0]`` about the
    origin. Each intersection point that is, within ``prec``, no closer to
    any point of ``cloud`` than to the origin is yielded, in search order.
    Yields nothing if ``shared_dists`` is empty.
    """
    if len(shared_dists) == 0:
        return
    dims = cloud.shape[-1]
    abs_q = shared_dists[0]
    sphere_intersect_func = _trilaterate if dims == 3 else _bilaterate

    for distance_tup in combinations(shared_dists[1:], dims):
        for basis in bases:
            res = sphere_intersect_func(*basis, *distance_tup, abs_q, prec)
            if res is None:
                continue
            # check the point is in the Voronoi domain of the origin
            cloud_res_dists = np.linalg.norm(cloud - res, axis=-1)
            if np.all(cloud_res_dists - abs_q + prec > 0):
                yield res


def _find_second_point(shared_dists, bases, cloud, prec):
    """Find the second motif point.

    ``shared_dists`` are the (non-lattice) distances shared by the first
    and second PDD rows. Returns the first candidate produced by
    ``_voronoi_candidates``, or None if there is none.
    """
    return next(_voronoi_candidates(shared_dists, bases, cloud, prec), None)


def _find_further_point(shared_dists1, shared_dists2, bases, cloud, q, prec):
    """Find a further motif point.

    ``shared_dists1`` are the distances this point's PDD row shares with
    the first row, ``shared_dists2`` those it shares with the second row,
    and ``q`` is the second motif point. Returns the first candidate from
    ``_voronoi_candidates(shared_dists1, ...)`` whose distance to ``q``
    matches a value of ``shared_dists2`` within ``prec``, or None.
    """
    # Spheres about the basis vectors (radii from shared_dists1[1:]) and
    # about the origin (radius shared_dists1[0]) meet in a single point.
    for res in _voronoi_candidates(shared_dists1, bases, cloud, prec):
        # check |q - res| is among the distances shared with the second row
        dist_diff = np.abs(shared_dists2 - np.linalg.norm(q - res))
        if np.any(dist_diff < prec):
            return res
    return None
154
155
156
def _neighbor_set(cell, prec):
    """(A superset of) the neighbor set of the origin for a lattice.

    Candidates are the lattice points nearest to the halves of all nonzero
    {-1, 0, 1}-combinations of the rows of ``cell``. A candidate ``w`` is
    kept when ``w / 2`` is, within ``prec``, at least as close to the
    origin as to every candidate, i.e. ``w / 2`` lies in the closed
    Voronoi domain of the origin with respect to the candidates.

    Parameters
    ----------
    cell : numpy.ndarray
        Unit cell, one basis vector per row.
    prec : float
        Absolute tolerance used in the Voronoi-domain test.

    Returns
    -------
    numpy.ndarray
        The kept lattice vectors, one per row.
    """
    k_ = 5  # number of nearest lattice points queried per half-vector
    dims = cell.shape[0]
    coeffs = np.array(list(product((-1, 0, 1), repeat=dims)))
    coeffs = coeffs[coeffs.any(axis=-1)]  # remove the zero vector

    # half of all combinations of basis vectors:
    # vecs[n] = sum_i coeffs[n, i] * cell[i] / 2
    vecs = np.sum(coeffs[:, :, None] * cell[None, :, :], axis=1) / 2

    # NOTE: relies on the first layer of generate_concentric_cloud being the
    # origin (as reconstruct also assumes), so cloud[0] is the origin.
    origin = np.zeros((1, dims))
    cloud_generator = iter(generate_concentric_cloud(origin, cell))
    cloud = np.concatenate((next(cloud_generator), next(cloud_generator)))
    tree = KDTree(cloud, compact_nodes=False, balanced_tree=False)
    dists, inds = tree.query(vecs, k=k_)

    # Grow the cloud two layers at a time until the k_ nearest distances of
    # every half-vector equal those of the previous query.
    while True:
        cloud = np.vstack((cloud, next(cloud_generator), next(cloud_generator)))
        tree = KDTree(cloud, compact_nodes=False, balanced_tree=False)
        new_dists, inds = tree.query(vecs, k=k_)
        if np.allclose(dists, new_dists, atol=0, rtol=1e-12):
            break
        dists = new_dists

    # Every lattice point found by any query except the origin. All k_
    # columns are used: a half-vector v / 2 is equidistant from the origin
    # and the lattice point v, so column 0 may hold v rather than the origin.
    candidate_inds = np.unique(inds.ravel())
    candidate_inds = candidate_inds[candidate_inds != 0]
    neighbor_set = cloud[candidate_inds]

    # Reduce the neighbor set: keep w_i if |w_i / 2 - w_j| >= |w_i| / 2 - prec
    # for every candidate w_j. The origin distance |w_i| / 2 belongs to row
    # i of the distance matrix, hence the [:, None] broadcast.
    half_norms = np.linalg.norm(neighbor_set, axis=-1) / 2
    half_to_lattice_dists = cdist(neighbor_set / 2, neighbor_set)
    voronoi_neighbors = np.all(
        half_to_lattice_dists - half_norms[:, None] + prec >= 0, axis=-1
    )
    return neighbor_set[voronoi_neighbors]
203
204
205
@numba.njit(cache=True, fastmath=True)
def _bilaterate(p1, p2, r1, r2, abs_val, prec):
    """Return the intersection of three circles.

    Intersects the circle of radius ``r1`` about ``p1`` with the circle of
    radius ``r2`` about ``p2``. Of the (up to) two intersection points,
    returns the first whose norm is within ``prec`` of ``abs_val`` (i.e.
    which also lies on the circle of radius ``abs_val`` about the origin),
    or None if there is no such point.
    """
    d = np.sqrt((p2[0] - p1[0]) ** 2 + (p2[1] - p1[1]) ** 2)

    # No (finite) intersection: circles too far apart, one strictly inside
    # the other, or concentric. Checked before dividing by d.
    if d > r1 + r2:
        return None
    if d < abs(r1 - r2):
        return None
    if d == 0:
        return None

    v = (p2 - p1) / d
    a = (r1**2 - r2**2 + d**2) / (2 * d)
    h_sq = r1**2 - a**2
    # Rounding near tangency can make h_sq slightly negative; sqrt would give
    # NaN, and under fastmath comparisons involving NaN are undefined, so
    # reject explicitly (as _trilaterate does for its discriminant).
    if h_sq < 0:
        return None
    h = np.sqrt(h_sq)

    x2 = p1[0] + a * v[0]
    y2 = p1[1] + a * v[1]
    x3 = x2 + h * v[1]
    y3 = y2 - h * v[0]
    x4 = x2 - h * v[1]
    y4 = y2 + h * v[0]
    q1 = np.array((x3, y3))
    q2 = np.array((x4, y4))

    if np.abs(np.sqrt(x3**2 + y3**2) - abs_val) < prec:
        return q1
    if np.abs(np.sqrt(x4**2 + y4**2) - abs_val) < prec:
        return q2
    return None
235
236
237
@numba.njit(cache=True)
def _trilaterate(p1, p2, p3, r1, r2, r3, abs_val, prec):
    """Return the intersection of four spheres.

    The spheres are centred at ``p1``, ``p2``, ``p3`` with radii ``r1``,
    ``r2``, ``r3``, plus one of radius ``abs_val`` about the origin. The
    first three are intersected by trilateration, which gives up to two
    points; the first of these whose norm is within ``prec`` of
    ``abs_val`` is returned, otherwise None.
    """
    # Quick rejections: a pair of spheres whose centre distance exceeds the
    # sum of their radii minus ``prec`` is treated as non-intersecting.
    if (
        np.linalg.norm(p1) > abs_val + r1 - prec
        or np.linalg.norm(p2) > abs_val + r2 - prec
        or np.linalg.norm(p3) > abs_val + r3 - prec
    ):
        return None
    if (
        np.linalg.norm(p1 - p2) > r1 + r2 - prec
        or np.linalg.norm(p1 - p3) > r1 + r3 - prec
        or np.linalg.norm(p2 - p3) > r2 + r3 - prec
    ):
        return None

    # Orthonormal frame anchored at p1: ex points to p2, ey lies in the
    # plane of p1, p2, p3.
    to_p2 = p2 - p1
    d = np.linalg.norm(to_p2)
    ex = to_p2 / d
    to_p3 = p3 - p1
    i = np.dot(ex, to_p3)
    perp = to_p3 - i * ex
    ey = perp / np.linalg.norm(perp)
    j = np.dot(ey, to_p3)

    # Coordinates of the intersection in that frame; z is fixed up to sign.
    x = (r1 * r1 - r2 * r2 + d * d) / (2 * d)
    y = (r1 * r1 - r3 * r3 - 2 * i * x + i * i + j * j) / (2 * j)
    z_sq = r1 * r1 - x * x - y * y
    if z_sq < 0:
        return None

    ez = np.cross(ex, ey)
    z = np.sqrt(z_sq)
    in_plane = p1 + x * ex + y * ey
    above = in_plane + z * ez
    below = in_plane - z * ez

    if np.abs(np.linalg.norm(above) - abs_val) < prec:
        return above
    if np.abs(np.linalg.norm(below) - abs_val) < prec:
        return below
    return None
280
281
282
def _unique_within_tol(arr, prec):
    """Drop every value of ``arr`` lying within ``prec`` of an earlier value.

    Order is preserved. Closeness is not transitive: a value is dropped if
    it is close to any earlier value, even one that was itself dropped.
    """
    pairwise_close = np.abs(arr[:, np.newaxis] - arr[np.newaxis, :]) < prec
    # Strict upper triangle: entry (i, j) with i < j means arr[j] repeats arr[i].
    is_repeat = np.triu(pairwise_close, k=1).any(axis=0)
    return arr[np.logical_not(is_repeat)]
285
286
287
def _remove_vals(vec, vals_to_remove, prec):
    """Return ``vec`` without entries within ``prec`` of any of ``vals_to_remove``.

    Order of the surviving entries is preserved.
    """
    near_removed = np.abs(vec[:, np.newaxis] - vals_to_remove) < prec
    keep = np.logical_not(near_removed.any(axis=-1))
    return vec[keep]
290
291
292
def _shared_vals(v1, v2, prec):
    """Return the values of ``v1`` that match a value of ``v2`` within ``prec``.

    One entry is produced per matching (v1, v2) pair, in row-major order,
    so a ``v1`` value matching several ``v2`` values is repeated.
    """
    matches = np.abs(v1[:, np.newaxis] - v2) < prec
    row_idx, _col_idx = np.nonzero(matches)
    return v1[row_idx]
295