Passed
Push — master ( 0f880f...d6a4c8 )
by Daniel
01:47
created

amd.utils._extend_signature()   A

Complexity

Conditions 1

Size

Total Lines 7
Code Lines 7

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 7
dl 0
loc 7
rs 10
c 0
b 0
f 0
cc 1
nop 1
1
"""Helpful utility functions, e.g. unit cell diameter, converting
2
cell parameters to Cartesian form, and an ETA class."""
3
4
from typing import Tuple
5
import time
6
import datetime
7
8
import scipy.spatial
9
import numpy as np
10
11
12
def diameter(cell):
    """Diameter of a unit cell in 3 or fewer dimensions.

    The diameter is the length of the longest diagonal of the
    parallelogram/parallelepiped spanned by the rows of ``cell``.

    Parameters
    ----------
    cell : ndarray
        Square array whose rows are the cell vectors, shape ``(d, d)`` with
        ``1 <= d <= 3``.

    Returns
    -------
    float
        The (non-negative) diameter of the cell.

    Raises
    ------
    ValueError
        If the cell dimension is not 1, 2 or 3.
    """
    cell = np.asarray(cell)
    dims = cell.shape[0]
    if not 1 <= dims <= 3:
        raise ValueError(f'diameter only implemented for dimensions <= 3 (passed {dims})')

    if dims == 1:
        # A 1D cell is a segment; its length does not depend on orientation.
        return abs(cell[0][0])

    if dims == 2:
        diagonals = np.array([cell[0] + cell[1], cell[0] - cell[1]])
    else:
        # The four body diagonals of a parallelepiped.
        diagonals = np.array([
            cell[0] + cell[1] + cell[2],
            cell[0] + cell[1] - cell[2],
            cell[0] - cell[1] + cell[2],
            -cell[0] + cell[1] + cell[2],
        ])
    return np.amax(np.linalg.norm(diagonals, axis=-1))
29
30
31
def cellpar_to_cell(a, b, c, alpha, beta, gamma):
    """Convert 3D cell parameters to a 3x3 array of cell vectors.

    Simplified version of the function of the same name in ase.geometry.
    Lengths ``a, b, c`` and angles ``alpha, beta, gamma`` (in degrees) become
    three row vectors: the first lies along x, the second in the xy-plane.

    Raises
    ------
    RuntimeError
        If the angles cannot describe a valid cell (the squared z-component
        of the third vector would be negative).
    """
    # Angles within this tolerance of +-90 degrees are treated as exact right
    # angles, so orthorhombic cells come out without rounding residue.
    tol = 2 * np.spacing(90.0, dtype=np.float64)  # around 1.4e-14

    def _cos_deg(angle):
        if abs(abs(angle) - 90.) < tol:
            return 0.
        return np.cos(angle * np.pi / 180.)

    cos_alpha = _cos_deg(alpha)
    cos_beta = _cos_deg(beta)
    cos_gamma = _cos_deg(gamma)

    if abs(gamma - 90) < tol:
        sin_gamma = 1.
    elif abs(gamma + 90) < tol:
        sin_gamma = -1.
    else:
        sin_gamma = np.sin(gamma * np.pi / 180.)

    # y- and squared z-components of the third (unit-length) cell vector.
    v3_y = (cos_alpha - cos_beta * cos_gamma) / sin_gamma
    v3_z_sqr = 1. - cos_beta ** 2 - v3_y ** 2
    if v3_z_sqr < 0:
        raise RuntimeError('Could not create unit cell from parameters '
                           f'a={a},b={b},c={c},α={alpha},β={beta},γ={gamma}')

    first = [a, 0, 0]
    second = [b * cos_gamma, b * sin_gamma, 0]
    third = [c * cos_beta, c * v3_y, c * np.sqrt(v3_z_sqr)]
    return np.array([first, second, third])
58
59
60
def cellpar_to_cell_2D(a, b, alpha):
    """2D unit cell parameters a,b,α --> cell as 2x2 ndarray.

    The first cell vector lies along x with length ``a``; the second has
    length ``b`` at angle ``alpha`` (degrees) from the first.

    Angles within ``2 * np.spacing(90.0)`` of +-90 degrees are snapped to exact
    right angles (the same tolerance ``cellpar_to_cell`` uses), so
    rectangular cells contain exact zeros instead of ~1e-16 rounding residue.
    """
    eps = 2 * np.spacing(90.0, dtype=np.float64)  # around 1.4e-14

    if abs(alpha - 90.) < eps:
        cos_alpha, sin_alpha = 0., 1.
    elif abs(alpha + 90.) < eps:
        cos_alpha, sin_alpha = 0., -1.
    else:
        cos_alpha = np.cos(alpha * np.pi / 180.)
        sin_alpha = np.sin(alpha * np.pi / 180.)

    return np.array([[a, 0],
                     [b * cos_alpha, b * sin_alpha]])
65
66
67
def neighbours_from_distance_matrix(
        n: int,
        dm: np.ndarray
) -> Tuple[np.ndarray, np.ndarray]:
    """Given a distance matrix, find the ``n`` nearest neighbours of each item.

    Parameters
    ----------
    n : int
        Number of nearest neighbours to find for each item. Must satisfy
        ``0 <= n <= dm.shape[1]`` for a 2D matrix, or ``0 <= n <= m - 1`` for
        a condensed matrix describing ``m`` items.
    dm : ndarray
        2D distance matrix or 1D condensed distance matrix. For a 2D matrix,
        row ``i`` holds distances from item ``i`` to every candidate and no
        entry is excluded. For a condensed matrix (as returned by
        ``scipy.spatial.distance.pdist``) each item's distance to itself is
        excluded.

    Returns
    -------
    tuple of ndarrays (nn_dm, inds)
        For item ``i``, ``nn_dm[i][j]`` is the distance from item ``i`` to its ``j+1`` st
        nearest neighbour, and ``inds[i][j]`` is the index of this neighbour (``j+1`` since
        index 0 is the first nearest neighbour).

    Raises
    ------
    ValueError
        If ``dm`` is neither 1D nor 2D, or if ``n`` is outside the valid range.
    """
    dm = np.asarray(dm)
    if dm.ndim not in (1, 2):
        raise ValueError(
            'Input must be an ndarray, either a 2D distance matrix '
            'or a condensed distance matrix (returned by pdist).')

    condensed = dm.ndim == 1
    if condensed:
        dm = scipy.spatial.distance.squareform(dm)

    n_cols = dm.shape[-1]
    # In the condensed case each item cannot be its own neighbour.
    n_available = n_cols - 1 if condensed else n_cols
    if not 0 <= n <= n_available:
        raise ValueError(f'n must be between 0 and {n_available}, got {n}')

    if condensed:
        # Take the n+1 closest candidates: at most one of them is the item
        # itself, so at least n remain after removing it. argpartition needs
        # kth < n_cols, hence the clamp (when n + 1 == n_cols all columns are
        # selected anyway).
        kth = min(n + 1, n_cols - 1)
        rows = []
        for i, row in enumerate(dm):
            candidates = np.argpartition(row, kth)[:n + 1]
            rows.append(candidates[candidates != i][:n])
        inds = np.array(rows)
    else:
        # Same clamp: n may equal the number of columns.
        kth = min(n, n_cols - 1)
        inds = np.argpartition(dm, kth, axis=-1)[:, :n]

    # inds[i, j] is some neighbour of item i (unordered); sort by distance.
    nn_dm = np.take_along_axis(dm, inds, axis=-1)
    order = np.argsort(nn_dm, axis=-1)
    inds = np.take_along_axis(inds, order, axis=-1)
    nn_dm = np.take_along_axis(nn_dm, order, axis=-1)
    return nn_dm, inds
115
116
117
def lattice_cubic(scale=1, dims=3):
    """Build a cubic lattice as a ``(motif, cell)`` pair.

    The pair can be passed to ``amd.AMD()`` or ``amd.PDD()``. The motif is a
    single point at the origin; the cell is the ``dims``-dimensional identity
    matrix multiplied by ``scale``.
    """
    origin_motif = np.zeros((1, dims))
    cubic_cell = np.identity(dims) * scale
    return origin_motif, cubic_cell
121
122
123
def random_cell(length_bounds=(1, 2), angle_bounds=(60, 120), dims=3):
    """Random unit cell.

    Cell lengths are drawn uniformly from ``length_bounds`` and cell angles
    (degrees) uniformly from ``angle_bounds``, using ``np.random.uniform``.

    Parameters
    ----------
    length_bounds : pair of float
        ``(low, high)`` bounds for each cell length.
    angle_bounds : pair of float
        ``(low, high)`` bounds for each cell angle, in degrees.
    dims : int
        Dimension of the cell; only 2 and 3 are supported.

    Returns
    -------
    ndarray
        The cell built by ``cellpar_to_cell`` (3D) or ``cellpar_to_cell_2D``
        (2D) from the sampled parameters.

    Raises
    ------
    ValueError
        If ``dims`` is not 2 or 3.
    RuntimeError
        Propagated from ``cellpar_to_cell`` when the sampled 3D angles do not
        form a valid cell.
    """
    if dims not in (2, 3):
        raise ValueError(f'random_cell only implemented for dimensions 2 and 3 (passed {dims})')

    lengths = [np.random.uniform(low=length_bounds[0], high=length_bounds[1])
               for _ in range(dims)]

    if dims == 3:
        angles = [np.random.uniform(low=angle_bounds[0], high=angle_bounds[1])
                  for _ in range(dims)]
        return cellpar_to_cell(*lengths, *angles)

    alpha = np.random.uniform(low=angle_bounds[0], high=angle_bounds[1])
    return cellpar_to_cell_2D(*lengths, alpha)
142
143
144
class ETA:
    """Progress tracker that prints an estimated time to completion.

    Construct with the total number of items, then call ``.update()`` once
    per finished item. Every ``update_rate`` items a percentage and ETA are
    printed (carriage-return terminated); when the last item finishes a
    summary with total time and throughput is printed.
    """

    # Weight given to the newest epoch time in the exponential moving average:
    # time_per_epoch <- factor * latest + (1 - factor) * time_per_epoch
    _moving_average_factor = 0.3

    def __init__(self, to_do, update_rate=100):
        now = time.perf_counter()
        self.to_do = to_do
        self.update_rate = update_rate
        self.counter = 0
        self.start_time = now
        self.tic = now
        self.time_per_epoch = None
        self.done = False

    def update(self):
        """Call when one item is finished."""
        self.counter += 1

        # Calls beyond the expected total are counted but otherwise ignored.
        if self.counter > self.to_do:
            return

        if self.counter == self.to_do:
            print(self._finished(), end='\r\n')
            self.done = True
        elif self.counter % self.update_rate == 0:
            print(self._end_epoch(), end='\r')

    def _end_epoch(self):
        """Fold the latest epoch time into the average; return a progress line."""
        now = time.perf_counter()
        latest = now - self.tic
        weight = ETA._moving_average_factor
        if self.time_per_epoch is not None:
            latest = weight * latest + (1 - weight) * self.time_per_epoch
        self.time_per_epoch = latest

        progress = f'{round(100 * self.counter / self.to_do, 2):.2f}'
        epochs_left = (self.to_do - self.counter) / self.update_rate
        seconds_left = int(epochs_left * self.time_per_epoch)
        eta = str(datetime.timedelta(seconds=seconds_left))
        self.tic = now
        # Trailing spaces overwrite leftovers from a longer previous line.
        return f'{progress}%, ETA {eta}' + ' ' * 30

    def _finished(self):
        """Return the summary line printed once all items are done."""
        elapsed = time.perf_counter() - self.start_time
        throughput = round(self.to_do / elapsed, 2)
        return (f'Total time: {round(elapsed, 2)}s, '
                f'n passes: {self.counter} '
                f'({throughput} passes/second)')
        return msg
200