amd.reconstruct   B
last analyzed

Complexity

Total Complexity 46

Size/Duplication

Total Lines 295
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
wmc 46
eloc 184
dl 0
loc 295
rs 8.72
c 0
b 0
f 0

9 Functions

Rating   Name   Duplication   Size   Complexity  
B _find_further_point() 0 23 7
B _find_second_point() 0 13 6
A _remove_vals() 0 3 1
A _shared_vals() 0 3 1
C reconstruct() 0 94 10
B _bilaterate() 0 30 7
C _trilaterate() 0 43 10
A _neighbor_set() 0 47 3
A _unique_within_tol() 0 3 1

How to fix   Complexity   

Complexity

Complex classes like amd.reconstruct often do a lot of different things. To break such a class down, we need to identify a cohesive component within that class. A common approach to find such a component is to look for fields/methods that share the same prefixes, or suffixes.

Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a sub-class, Extract Subclass is also a candidate, and is often faster.

1
"""Functions for reconstructing a periodic set up to isometry from its
2
PDD. This is possible 'in a general position', see our papers for more.
3
"""
4
5
from itertools import combinations, permutations, product
6
7
import numpy as np
8
import numba
9
from scipy.spatial.distance import cdist
10
from scipy.spatial import KDTree
11
12
from ._nearest_neighbors import generate_concentric_cloud
13
from .utils import diameter
14
15
from ._types import FloatArray
16
17
__all__ = ["reconstruct"]
18
19
20
def reconstruct(pdd: FloatArray, cell: FloatArray) -> FloatArray:
    """Reconstruct a motif from a PDD and unit cell. This function will
    only work if ``pdd`` has enough columns, such that the last column
    has all values larger than 2 times the diameter of the unit cell. It
    also expects an uncollapsed PDD with no weights column. Do not use
    ``amd.PDD`` to compute the PDD for this function, instead use
    ``amd.PDD_reconstructable`` which returns a version of the PDD which
    is passable to this function. This function is quite slow and run
    time may vary a lot arbitrarily depending on input.

    Parameters
    ----------
    pdd : :class:`numpy.ndarray`
        The PDD of the periodic set to reconstruct. Needs `k` at least
        large enough so all values in the last column of pdd are greater
        than :code:`2 * diameter(cell)`, and needs to be uncollapsed
        without weights. Use amd.PDD_reconstructable to get a PDD which
        is acceptable for this argument.
    cell : :class:`numpy.ndarray`
        Unit cell of the periodic set to reconstruct.

    Returns
    -------
    :class:`numpy.ndarray`
        The reconstructed motif of the periodic set, one point per row
        of ``pdd``. The first point is the origin; whenever more than
        one point is reconstructed, all points are wrapped into the
        unit cell.

    Raises
    ------
    ValueError
        If ``cell`` is not 2 or 3 dimensional.
    RuntimeError
        If a motif point could not be located.
    """

    # TODO: get a more reduced neighbor set
    # TODO: find all shared distances in a big operation at the start
    # TODO: move PREC variable to its proper place
    PREC = 1e-10

    dims = cell.shape[0]
    if dims not in (2, 3):
        raise ValueError(
            "Reconstructing from PDD only implemented for 2 and 3 dimensions"
        )

    # The first motif point is placed at the origin. With a single row
    # there is nothing else to find, so return before any lattice work.
    motif = [np.zeros((dims,))]
    if pdd.shape[0] == 1:
        return np.array(motif)

    diam = diameter(cell)

    # finding lattice distances so we can ignore them
    cloud = _lattice_cloud_within(cell, dims, diam)

    # is (a superset of) lattice points close enough to Voronoi domain
    nn_set = _neighbor_set(cell, PREC)
    lattice_dists = np.linalg.norm(cloud, axis=-1)
    lattice_dists = lattice_dists[lattice_dists <= diam]
    lattice_dists = _unique_within_tol(lattice_dists, PREC)

    # remove lattice distances from first and second rows
    row1_reduced = _remove_vals(pdd[0], lattice_dists, PREC)
    row2_reduced = _remove_vals(pdd[1], lattice_dists, PREC)
    # get shared dists between first and second rows
    shared_dists = _shared_vals(row1_reduced, row2_reduced, PREC)
    shared_dists = _unique_within_tol(shared_dists, PREC)

    bases = _independent_bases(nn_set, dims, PREC)

    q = _find_second_point(shared_dists, bases, cloud, PREC)
    if q is None:
        raise RuntimeError("Second point of motif could not be found.")
    motif.append(q)

    # Every further row is matched against both the first and the second
    # row (the loop is empty when pdd has exactly two rows).
    for row in pdd[2:]:
        row_reduced = _remove_vals(row, lattice_dists, PREC)
        shared_dists1 = _shared_vals(row1_reduced, row_reduced, PREC)
        shared_dists2 = _shared_vals(row2_reduced, row_reduced, PREC)
        shared_dists1 = _unique_within_tol(shared_dists1, PREC)
        shared_dists2 = _unique_within_tol(shared_dists2, PREC)
        q_ = _find_further_point(
            shared_dists1, shared_dists2, bases, cloud, q, PREC
        )
        if q_ is None:
            raise RuntimeError("Further point of motif could not be found.")
        motif.append(q_)

    # Wrap every point into the unit cell via fractional coordinates.
    # This is applied for two-point motifs as well, so the output
    # convention does not depend on the number of motif points.
    motif = np.array(motif)
    return np.mod(motif @ np.linalg.inv(cell), 1) @ cell


def _lattice_cloud_within(cell, dims, diam):
    """Stack successive layers from ``generate_concentric_cloud`` for
    the origin, skipping the first yielded layer (the origin itself) and
    stopping at the first layer with no point within ``diam`` of the
    origin.
    """

    layers = iter(generate_concentric_cloud(np.zeros((1, dims)), cell))
    next(layers)  # the origin
    collected = []
    layer = next(layers)
    while np.any(np.linalg.norm(layer, axis=-1) <= diam):
        collected.append(layer)
        layer = next(layers)
    return np.concatenate(collected)


def _independent_bases(nn_set, dims, prec):
    """Return all orderings of every ``dims``-subset of ``nn_set`` whose
    determinant exceeds ``prec`` in absolute value.
    """

    bases = []
    for vecs in combinations(nn_set, dims):
        vecs = np.asarray(vecs)
        if np.abs(np.linalg.det(vecs)) > prec:
            bases.extend(permutations(vecs, dims))
    return bases
114
115
116
def _find_second_point(shared_dists, bases, cloud, prec):
117
    dims = cloud.shape[-1]
118
    abs_q = shared_dists[0]
119
    sphere_intersect_func = _trilaterate if dims == 3 else _bilaterate
120
121
    for distance_tup in combinations(shared_dists[1:], dims):
122
        for basis in bases:
123
            res = sphere_intersect_func(*basis, *distance_tup, abs_q, prec)
124
            if res is None:
125
                continue
126
            cloud_res_dists = np.linalg.norm(cloud - res, axis=-1)
127
            if np.all(cloud_res_dists - abs_q + prec > 0):
128
                return res
129
130
131
def _find_further_point(shared_dists1, shared_dists2, bases, cloud, q, prec):
132
    # distance from origin (first motif point) to further point
133
    dims = cloud.shape[-1]
134
    abs_q_ = shared_dists1[0]
135
136
    # try all ordered subsequences of distances shared between first and
137
    # further row, with all combinations of the vectors in the neighbor set
138
    # forming a basis, see if spheres centered at the vectors with the shared
139
    # distances as radii intersect at 4 (3 dims) points.
140
    sphere_intersect_func = _trilaterate if dims == 3 else _bilaterate
141
    for distance_tup in combinations(shared_dists1[1:], dims):
142
        for basis in bases:
143
            res = sphere_intersect_func(*basis, *distance_tup, abs_q_, prec)
144
            if res is None:
145
                continue
146
            # check point is in the voronoi domain
147
            cloud_res_dists = np.linalg.norm(cloud - res, axis=-1)
148
            if not np.all(cloud_res_dists - abs_q_ + prec > 0):
149
                continue
150
            # check |p - point| is among the row's shared distances
151
            dist_diff = np.abs(shared_dists2 - np.linalg.norm(q - res))
152
            if np.any(dist_diff < prec):
153
                return res
154
155
156
def _neighbor_set(cell, prec):
    """(A superset of) the neighbor set of origin for a lattice.

    Candidate lattice points are gathered as nearest neighbors of the
    halved sums ``c @ cell / 2`` over all non-zero coefficient vectors
    ``c`` with entries in {-1, 0, 1}. A candidate ``v`` is kept when no
    other candidate is closer to ``v / 2`` than the origin is, within
    ``prec``.

    Parameters
    ----------
    cell : :class:`numpy.ndarray`
        Square array whose rows are the basis vectors of the lattice.
    prec : float
        Absolute tolerance used in the final filtering step.

    Returns
    -------
    :class:`numpy.ndarray`
        The retained lattice vectors, one per row.
    """

    k_ = 5
    dims = cell.shape[0]
    coeffs = np.array(list(product((-1, 0, 1), repeat=dims)))
    coeffs = coeffs[coeffs.any(axis=-1)]  # remove (0,0,0)

    # half of all combinations of basis vectors
    vecs = np.array(
        [np.sum(cell * c[None, :].T, axis=0) / 2 for c in coeffs]
    )

    origin = np.zeros((1, dims))
    cloud_generator = iter(generate_concentric_cloud(origin, cell))
    cloud = np.concatenate((next(cloud_generator), next(cloud_generator)))
    tree = KDTree(cloud, compact_nodes=False, balanced_tree=False)
    dists, inds = tree.query(vecs, k=k_)
    # Seed the comparison buffer with NaN rather than uninitialised
    # memory. NaN never compares close, so the loop always runs until
    # two consecutive queries on enlarged clouds agree.
    dists_ = np.full_like(dists, np.nan)

    while not np.allclose(dists, dists_, atol=0, rtol=1e-12):
        dists = dists_
        cloud = np.vstack(
            (cloud, next(cloud_generator), next(cloud_generator))
        )
        tree = KDTree(cloud, compact_nodes=False, balanced_tree=False)
        dists_, inds = tree.query(vecs, k=k_)

    # NOTE(review): this drops each query's nearest hit and then index 0,
    # which presumably is the origin (first point yielded by the cloud
    # generator) -- confirm against generate_concentric_cloud.
    tmp_inds = np.unique(inds[:, 1:].flatten())
    tmp_inds = tmp_inds[tmp_inds != 0]
    neighbor_set = cloud[tmp_inds]

    # reduce neighbor set
    # half the lattice points and find their nearest neighbors in the lattice
    neighbor_set_half = neighbor_set / 2
    # for each of these vectors, check if 0 is a nearest neighbor.
    # so, check if the dist to 0 is leq (within tol) than dist to all other
    # lattice points.
    nn_norms = np.linalg.norm(neighbor_set, axis=-1)
    halves_norms = nn_norms / 2
    halves_to_lattice_dists = cdist(neighbor_set_half, neighbor_set)

    # Row i of halves_to_lattice_dists holds the distances from the i-th
    # halved vector to every candidate. It must be compared with that
    # same vector's distance to the origin, so halves_norms is broadcast
    # down the rows (one value per halved vector), not across columns.
    voronoi_neighbors = np.all(
        halves_to_lattice_dists - halves_norms[:, None] + prec >= 0, axis=-1
    )
    neighbor_set = neighbor_set[voronoi_neighbors]
    return neighbor_set
203
204
205
@numba.njit(cache=True, fastmath=True)
def _bilaterate(p1, p2, r1, r2, abs_val, prec):
    """Return the intersection of three circles.

    Two circles are centred at ``p1`` and ``p2`` with radii ``r1`` and
    ``r2``; the third is centred at the origin with radius ``abs_val``.

    Parameters
    ----------
    p1, p2 : :class:`numpy.ndarray`
        Centres of the first two circles (2D points).
    r1, r2 : float
        Radii of the circles centred at ``p1`` and ``p2``.
    abs_val : float
        Required distance of the result from the origin.
    prec : float
        Absolute tolerance on the distance from the origin.

    Returns
    -------
    :class:`numpy.ndarray` or None
        An intersection point of the first two circles whose norm equals
        ``abs_val`` within ``prec`` (the first of the two candidates that
        qualifies), or None if the circles do not intersect or neither
        candidate qualifies.
    """

    d = np.sqrt((p2[0] - p1[0]) ** 2 + (p2[1] - p1[1]) ** 2)

    # Reject before dividing by d: circles too far apart, one inside the
    # other, or coincident centres.
    if d > r1 + r2:
        return None
    if d < abs(r1 - r2):
        return None
    if d == 0:
        return None

    # Unit vector from p1 to p2; d is non-zero here.
    v = (p2 - p1) / d

    # a: distance from p1 along v to the chord through both intersections.
    a = (r1**2 - r2**2 + d**2) / (2 * d)
    # h: half the chord length. For (near-)tangent circles rounding can
    # push r1**2 - a**2 slightly below zero; clamp so sqrt never yields
    # NaN, which fastmath comparisons do not handle reliably.
    h_sq = r1**2 - a**2
    if h_sq < 0:
        h_sq = 0.0
    h = np.sqrt(h_sq)

    x2 = p1[0] + a * v[0]
    y2 = p1[1] + a * v[1]
    x3 = x2 + h * v[1]
    y3 = y2 - h * v[0]
    x4 = x2 - h * v[1]
    y4 = y2 + h * v[0]
    q1 = np.array((x3, y3))
    q2 = np.array((x4, y4))

    if np.abs(np.sqrt(x3**2 + y3**2) - abs_val) < prec:
        return q1
    if np.abs(np.sqrt(x4**2 + y4**2) - abs_val) < prec:
        return q2
    return None
235
236
237
@numba.njit(cache=True)
def _trilaterate(p1, p2, p3, r1, r2, r3, abs_val, prec):
    """Return the intersection of four spheres.

    Three spheres are centred at ``p1``, ``p2``, ``p3`` with radii
    ``r1``, ``r2``, ``r3``; the fourth is centred at the origin with
    radius ``abs_val``. Returns a point of the first three spheres whose
    norm equals ``abs_val`` within ``prec``, or None if none is found.
    """

    # Quick rejection: a centre too far from the origin for its sphere
    # to reach the sphere of radius abs_val around the origin.
    if (
        np.linalg.norm(p1) > abs_val + r1 - prec
        or np.linalg.norm(p2) > abs_val + r2 - prec
        or np.linalg.norm(p3) > abs_val + r3 - prec
    ):
        return None

    # Quick rejection: a pair of centres too far apart for their
    # spheres to meet.
    if (
        np.linalg.norm(p1 - p2) > r1 + r2 - prec
        or np.linalg.norm(p1 - p3) > r1 + r3 - prec
        or np.linalg.norm(p2 - p3) > r2 + r3 - prec
    ):
        return None

    # Orthonormal frame anchored at p1: e_x points towards p2, e_y lies
    # in the plane of the three centres.
    p1_to_p2 = p2 - p1
    d = np.linalg.norm(p1_to_p2)
    e_x = p1_to_p2 / d
    p1_to_p3 = p3 - p1
    i = np.dot(e_x, p1_to_p3)
    off_axis = p1_to_p3 - i * e_x
    e_y = off_axis / np.linalg.norm(off_axis)
    j = np.dot(e_y, p1_to_p3)

    # Coordinates of the intersection in that frame.
    x = (r1 * r1 - r2 * r2 + d * d) / (2 * d)
    y = (r1 * r1 - r3 * r3 - 2 * i * x + i * i + j * j) / (2 * j)
    z_squared = r1 * r1 - x * x - y * y

    if z_squared < 0:
        return None

    e_z = np.cross(e_x, e_y)
    z = np.sqrt(z_squared)
    in_plane = p1 + x * e_x + y * e_y
    candidate_up = in_plane + z * e_z
    candidate_down = in_plane - z * e_z

    # Keep whichever candidate (upper one first) also lies on the sphere
    # of radius abs_val about the origin.
    if np.abs(np.linalg.norm(candidate_up) - abs_val) < prec:
        return candidate_up
    if np.abs(np.linalg.norm(candidate_down) - abs_val) < prec:
        return candidate_down
    return None
280
281
282
def _unique_within_tol(arr, prec):
283
    """Return only unique values in a vector within ``prec``."""
284
    return arr[~np.any(np.triu(np.abs(arr[:, None] - arr) < prec, 1), axis=0)]
285
286
287
def _remove_vals(vec, vals_to_remove, prec):
288
    """Remove specified values in vec, within ``prec``."""
289
    return vec[~np.any(np.abs(vec[:, None] - vals_to_remove) < prec, axis=-1)]
290
291
292
def _shared_vals(v1, v2, prec):
293
    """Return values shared between v1, v2 within ``prec``."""
294
    return v1[np.argwhere(np.abs(v1[:, None] - v2) < prec)[:, 0]]
295