| 1 |  |  | """General utility functions. | 
            
                                                                                                            
                            
            
                                    
            
            
                | 2 |  |  | """ | 
            
                                                                                                            
                            
            
                                    
            
            
                | 3 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 4 |  |  | from typing import Tuple | 
            
                                                                                                            
                            
            
                                    
            
            
                | 5 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 6 |  |  | import numpy as np | 
            
                                                                                                            
                            
            
                                    
            
            
                | 7 |  |  | import numba | 
            
                                                                                                            
                            
            
                                    
            
            
                | 8 |  |  | from scipy.spatial.distance import squareform | 
            
                                                                                                            
                            
            
                                    
            
            
                | 9 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 10 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                def diameter(cell):
                    """Diameter of a unit cell (as a square matrix in
                    Cartesian/Orthogonal form), i.e. the largest distance between two
                    vertices of the parallelotope spanned by its rows.

                    The diameter is the length of the longest body diagonal
                    ``|v_0 ± v_1 ± ... ± v_{d-1}|``. Flipping every sign gives a vector of
                    the same length, so only the ``2**(d-1)`` sign patterns with a ``+``
                    on the first cell vector need to be checked. This works in any
                    number of dimensions.

                    Parameters
                    ----------
                    cell : numpy.ndarray
                        Matrix whose rows are the cell vectors.

                    Returns
                    -------
                    float
                        The diameter of the cell. For a 1D cell this is ``abs(cell[0][0])``.

                    Raises
                    ------
                    ValueError
                        If ``cell`` contains no cell vectors.
                    """
                    import itertools

                    cell = np.asarray(cell)
                    dims = cell.shape[0]
                    if dims == 0:
                        raise ValueError('diameter() needs at least one cell vector '
                                         f'(passed cell shape {cell.shape}).')

                    # Each row of ``signs`` is (+1, ±1, ..., ±1). There are 2**(dims - 1)
                    # rows, which is tiny for the low-dimensional cells this is used on.
                    signs = np.array([
                        (1.0,) + rest
                        for rest in itertools.product((1.0, -1.0), repeat=dims - 1)
                    ])
                    diagonals = signs @ cell
                    return np.amax(np.linalg.norm(diagonals, axis=-1))
            
                                                                                                            
                            
            
                                    
            
            
                | 35 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 36 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                @numba.njit()
                def cellpar_to_cell(a, b, c, alpha, beta, gamma):
                    """Convert 3D unit cell parameters a,b,c,α,β,γ (angles in degrees)
                    into the cell as a 3x3 :class:`numpy.ndarray` whose rows are the cell
                    vectors. A simplified version of the function of the same name in
                    :mod:`ase.geometry`.

                    The first cell vector lies along x, the second lies in the xy-plane,
                    and the third is fixed by the remaining angles.

                    Raises
                    ------
                    RuntimeError
                        If the parameters cannot form a cell (the squared z component of
                        the third vector would be negative).
                    """

                    # Angles within this tolerance of ±90 degrees get exact cos/sin
                    # values, so that orthogonal cells come out exactly orthogonal.
                    tol = 2 * np.spacing(90)  # ~1.4e-14

                    if abs(abs(alpha) - 90) < tol:
                        cos_a = 0
                    else:
                        cos_a = np.cos(alpha * np.pi / 180)

                    if abs(abs(beta) - 90) < tol:
                        cos_b = 0
                    else:
                        cos_b = np.cos(beta * np.pi / 180)

                    if abs(abs(gamma) - 90) < tol:
                        cos_g = 0
                    else:
                        cos_g = np.cos(gamma * np.pi / 180)

                    if abs(gamma - 90) < tol:
                        sin_g = 1
                    elif abs(gamma + 90) < tol:
                        sin_g = -1
                    else:
                        sin_g = np.sin(gamma * np.pi / 180.)

                    # Components of the third (unit-length before scaling by c) vector.
                    # The z component follows from the unit-length requirement.
                    third_y = (cos_a - cos_b * cos_g) / sin_g
                    third_z_sq = 1 - cos_b ** 2 - third_y ** 2
                    if third_z_sq < 0:
                        raise RuntimeError('Could not create unit cell from given parameters.')

                    out = np.zeros((3, 3))
                    out[0, 0] = a
                    out[1, 0] = b * cos_g
                    out[1, 1] = b * sin_g
                    out[2, 0] = c * cos_b
                    out[2, 1] = c * third_y
                    out[2, 2] = c * np.sqrt(third_z_sq)
                    return out
            
                                                                                                            
                            
            
                                    
            
            
                | 71 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 72 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                @numba.njit()
                def cellpar_to_cell_2D(a, b, alpha):
                    """2D unit cell parameters a,b,α (α in degrees) --> cell as 2x2
                    :class:`numpy.ndarray` whose rows are the cell vectors. The first
                    vector lies along x.

                    As in the 3D ``cellpar_to_cell``, an angle within ~1.4e-14 of ±90
                    degrees is given exact cos/sin values. This makes rectangular cells
                    exactly orthogonal instead of carrying a ~6e-17 rounding residue from
                    ``np.cos(np.pi / 2)``.
                    """

                    eps = 2 * np.spacing(90)  # ~1.4e-14

                    if abs(abs(alpha) - 90) < eps:
                        cos_alpha = 0.0
                    else:
                        cos_alpha = np.cos(alpha * np.pi / 180.)

                    if abs(alpha - 90) < eps:
                        sin_alpha = 1.0
                    elif abs(alpha + 90) < eps:
                        sin_alpha = -1.0
                    else:
                        sin_alpha = np.sin(alpha * np.pi / 180.)

                    cell = np.zeros((2, 2))
                    cell[0, 0] = a
                    cell[1, 0] = b * cos_alpha
                    cell[1, 1] = b * sin_alpha

                    return cell
            
                                                                                                            
                            
            
                                    
            
            
                | 86 |  |  |  | 
            
                                                                                                            
                                                                
            
                                    
            
            
                | 87 |  |  |  | 
            
                                                                        
                            
            
                                    
            
            
                def cell_to_cellpar(cell):
                    """Unit cell as a 3x3 :class:`numpy.ndarray` --> array of 6 cell
                    parameters: the 3 lengths a, b, c followed by the 3 angles α, β, γ
                    (in degrees).

                    α is the angle between the 2nd and 3rd cell vectors, β is between the
                    1st and 3rd, and γ is between the 1st and 2nd.
                    """

                    lengths = np.linalg.norm(cell, axis=-1)
                    angles = np.empty(3)
                    # Index pairs producing α, β, γ in that order.
                    for k, (i, j) in enumerate(((1, 2), (0, 2), (0, 1))):
                        cos_ang = np.dot(cell[i], cell[j]) / (lengths[i] * lengths[j])
                        # For (anti)parallel vectors, rounding can push |cos| just above 1,
                        # where arccos returns NaN. Clamp the value into arccos' domain.
                        angles[k] = np.rad2deg(np.arccos(np.clip(cos_ang, -1.0, 1.0)))
                    return np.concatenate((lengths, angles))
            
                                                                                                            
                            
            
                                    
            
            
                | 99 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 100 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                def cell_to_cellpar_2D(cell):
                    """Unit cell as a 2x2 :class:`numpy.ndarray` --> array of 3 cell
                    parameters: the 2 lengths a, b followed by the angle α between the
                    two cell vectors (in degrees).
                    """

                    lengths = np.linalg.norm(cell, axis=-1)
                    cos_ang = np.dot(cell[0], cell[1]) / (lengths[0] * lengths[1])
                    # For (anti)parallel vectors, rounding can push |cos| just above 1,
                    # where arccos returns NaN. Clamp the value into arccos' domain.
                    ang_rad = np.arccos(np.clip(cos_ang, -1.0, 1.0))
                    cellpar = np.zeros((3, ))
                    cellpar[0] = lengths[0]
                    cellpar[1] = lengths[1]
                    cellpar[2] = np.rad2deg(ang_rad)
                    return cellpar
            
                                                                                                            
                            
            
                                    
            
            
                | 113 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 114 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                def neighbours_from_distance_matrix(
                        n: int,
                        dm: np.ndarray
                ) -> Tuple[np.ndarray, np.ndarray]:
                    """Given a distance matrix, find the n nearest neighbours of each
                    item.

                    Parameters
                    ----------
                    n : int
                        Number of nearest neighbours to find for each item.
                    dm : :class:`numpy.ndarray`
                        2D distance matrix or 1D condensed distance matrix. A 2D
                        matrix is searched row by row as given: every column is a
                        candidate neighbour of every row and nothing is excluded. A
                        1D condensed matrix (as returned by SciPy's ``pdist``) is
                        expanded with ``squareform`` and each item's own index is
                        excluded from its neighbours.

                    Returns
                    -------
                    (nn_dm, inds) : Tuple[:class:`numpy.ndarray`, :class:`numpy.ndarray`]
                        ``nn_dm[i][j]`` is the distance from item :math:`i` to its
                        :math:`j+1` st nearest neighbour, and ``inds[i][j]`` is the
                        index of this neighbour (:math:`j+1` since index 0 is the first
                        nearest neighbour). Both are sorted by increasing distance
                        along the last axis.

                    Raises
                    ------
                    ValueError
                        If ``dm`` is neither 1D nor 2D, or if ``n`` exceeds the number
                        of available neighbours (the number of columns of a 2D matrix,
                        or the number of items minus one for a condensed matrix).
                    """

                    dm = np.asarray(dm)

                    # 2D distance matrix: every column is a candidate neighbour.
                    if dm.ndim == 2:
                        n_cols = dm.shape[-1]
                        if n > n_cols:
                            raise ValueError(
                                f"Cannot find {n} neighbours: the distance matrix "
                                f"has only {n_cols} columns."
                            )
                        # argpartition requires kth < n_cols. Clamping lets n == n_cols
                        # work and leaves the selected indices unchanged otherwise.
                        kth = min(n, n_cols - 1)
                        inds = np.array([np.argpartition(row, kth)[:n] for row in dm])

                    # 1D condensed distance vector
                    elif dm.ndim == 1:
                        dm = squareform(dm)
                        n_items = dm.shape[0]
                        if n > n_items - 1:
                            raise ValueError(
                                f"Cannot find {n} neighbours among {n_items} items "
                                f"(at most {n_items - 1} excluding the item itself)."
                            )
                        # Take n+1 candidates so that n remain after dropping the
                        # item itself (its zero diagonal entry). kth is clamped to
                        # the last valid position so that n == n_items - 1 works.
                        kth = min(n + 1, n_items - 1)
                        inds = []
                        for i, row in enumerate(dm):
                            inds_row = np.argpartition(row, kth)[:n + 1]
                            inds_row = inds_row[inds_row != i][:n]
                            inds.append(inds_row)
                        inds = np.array(inds)

                    else:
                        msg = "Input must be a NumPy ndarray, either a 2D distance matrix " \
                              "or a condensed distance matrix (returned by SciPy's pdist)."
                        raise ValueError(msg)

                    # inds are the indexes of nns: inds[i,j] is the j-th nn to point i.
                    # argpartition does not order the selected entries, so sort them
                    # by distance here.
                    nn_dm = np.take_along_axis(dm, inds, axis=-1)
                    sorted_inds = np.argsort(nn_dm, axis=-1)
                    inds = np.take_along_axis(inds, sorted_inds, axis=-1)
                    nn_dm = np.take_along_axis(nn_dm, sorted_inds, axis=-1)
                    return nn_dm, inds
            
                                                                                                            
                            
            
                                    
            
            
                | 165 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 166 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                def random_cell(length_bounds=(1, 2), angle_bounds=(60, 120), dims=3):
                    """Dimensions 2 and 3 only. Random unit cell with uniformly chosen
                    length and angle parameters between bounds.

                    Parameters
                    ----------
                    length_bounds : tuple, default (1, 2)
                        ``(low, high)`` bounds for :func:`numpy.random.uniform`, used
                        for each cell length.
                    angle_bounds : tuple, default (60, 120)
                        ``(low, high)`` bounds for :func:`numpy.random.uniform`, used
                        for each cell angle. The values are passed unchanged to the
                        cell-construction helper (degrees, going by the defaults;
                        confirm against ``cellpar_to_cell``).
                    dims : int, default 3
                        Dimension of the cell; only 2 and 3 are supported.

                    Returns
                    -------
                    cell
                        The result of ``cellpar_to_cell`` (dims 3) or
                        ``cellpar_to_cell_2D`` (dims 2) for the sampled parameters.

                    Raises
                    ------
                    NotImplementedError
                        If ``dims`` is not 2 or 3.

                    Notes
                    -----
                    Sampling uses NumPy's global random state (``np.random``).
                    """

                    ll, lu = length_bounds
                    al, au = angle_bounds

                    if dims == 3:
                        # Resample until cellpar_to_cell accepts the parameters. It
                        # raises RuntimeError for rejected combinations, presumably
                        # angles that cannot form a valid 3D cell.
                        while True:
                            lengths = [np.random.uniform(low=ll, high=lu) for _ in range(dims)]
                            angles = [np.random.uniform(low=al, high=au) for _ in range(dims)]

                            try:
                                cell = cellpar_to_cell(*lengths, *angles)
                                break
                            except RuntimeError:
                                continue

                    elif dims == 2:
                        lengths = [np.random.uniform(low=ll, high=lu) for _ in range(dims)]
                        alpha = np.random.uniform(low=al, high=au)
                        cell = cellpar_to_cell_2D(*lengths, alpha)

                    else:
                        # The f prefix belongs outside the quotes so dims is interpolated.
                        msg = 'random_cell only implimented for dimensions 2 and 3 (passed ' \
                              f'{dims})'
                        raise NotImplementedError(msg)

                    return cell
            
                                                        
            
                                    
            
            
                | 197 |  |  |  |