Completed
Push — master ( fe60a9...f3d068 )
by Wouter
04:03
created

regularize_matrix()   B

Complexity

Conditions 6

Size

Total Lines 45

Duplication

Lines 0
Ratio 0 %

Code Coverage

Tests 13
CRAP Score 6.4689

Importance

Changes 0
Metric Value
cc 6
dl 0
loc 45
ccs 13
cts 17
cp 0.7647
crap 6.4689
rs 7.5384
c 0
b 0
f 0
1
"""
2
Utility functions necessary for different classifiers.
3
4
Contains algebraic operations, and label encodings.
5
"""
6
7 1
import numpy as np
8 1
import numpy.linalg as al
9 1
import scipy.stats as st
10
11
12 1
def one_hot(y, fill_k=False, one_not=False):
    """
    Map to one-hot encoding.

    INPUT   (1) array 'y': labels of N samples (any array-like, e.g. a list)
            (2) boolean 'fill_k': if True, the active entry of sample n is
                set to its class index k instead of 1 (default: False)
            (3) boolean 'one_not': if True, inactive entries are -1 instead
                of 0 (default: False)
    OUTPUT  (1) array 'Y': N by K float encoding, K the number of unique labels
            (2) array 'labels': sorted unique labels; column k of 'Y'
                corresponds to labels[k]
    """
    # Accept any array-like input (a plain list has no .shape)
    y = np.asarray(y)

    # Check labels (sorted unique values)
    labels = np.unique(y)

    # Number of classes
    K = len(labels)

    # Number of samples
    N = y.shape[0]

    # Boolean N by K matrix: entry (n, k) is True iff y[n] equals labels[k]
    y_col = y[:, np.newaxis] if y.ndim == 1 else y
    mask = (y_col == labels)

    # Preallocate array with the value used for inactive entries
    Y = np.full((N, K), -1.0 if one_not else 0.0)

    # Set the active (k-th) entry of each sample
    if fill_k:
        # Column index of every True entry. np.nonzero lists entries in the
        # same row-major order in which boolean-mask assignment fills Y.
        Y[mask] = np.nonzero(mask)[1]
    else:
        Y[mask] = 1

    return Y, labels
41
42
43 1
def regularize_matrix(A, a=0.0):
    """
    Regularize matrix by ensuring minimum eigenvalues.

    For a > 0 the matrix is symmetrized, every eigenvalue below 'a' is raised
    to 'a', and the matrix is reconstructed from its eigendecomposition.

    INPUT   (1) array 'A': square matrix
            (2) float 'a': constraint on minimum eigenvalue (default: 0.0);
                for a == 0 the input matrix is returned unchanged
    OUTPUT  (1) array 'B': constrained matrix; for a > 0 it is symmetric with
                eigenvalues max(e, a) for each eigenvalue e of (A + A.T) / 2
    RAISES  ValueError: if 'A' is not square, contains NaNs or infinities,
                or 'a' is negative
    """
    # Check for square matrix
    N, M = A.shape
    if N != M:
        raise ValueError('Matrix not square.')

    # Check for valid matrix entries
    if not np.all(np.isfinite(A)):
        raise ValueError('Matrix contains NaNs or infinities.')

    # Check for non-negative minimum eigenvalue
    if a < 0:
        raise ValueError('minimum eigenvalue cannot be negative.')
    if a == 0:
        # No constraint requested: return the input as-is (not symmetrized)
        return A

    # Ensure symmetric matrix, so that the symmetric eigensolver applies
    A = (A + A.T) / 2

    # Symmetric eigendecomposition: real eigenvalues and orthonormal
    # eigenvectors, so V.T is the inverse of V in the reconstruction below
    E, V = al.eigh(A)

    # Raise eigenvalues below a to a. This equals subtracting a, capping
    # negatives at zero, and adding a back.
    E = np.maximum(E, a)

    # Reconstruct matrix V diag(E) V^T (V * E scales column j of V by E[j])
    return np.dot(V * E, V.T)
88
89
90 1
def is_pos_def(X):
    """
    Check for positive definiteness.

    INPUT   (1) array 'X': square matrix
    OUTPUT  (1) boolean: True when every eigenvalue of 'X' is strictly positive
    """
    # General (non-symmetric) eigenvalue solver.
    # NOTE(review): for non-symmetric X the eigenvalues may be complex.
    eigenvalues = al.eigvals(X)
    return (eigenvalues > 0).all()
93
94
95 1
def nullspace(A, atol=1e-13, rtol=0):
    """
    Compute an approximate basis for the nullspace of A.

    INPUT   (1) array 'A': 1-D array with length k will be treated
                as a 2-D with shape (1, k).
            (2) float 'atol': the absolute tolerance for a zero singular value.
                Singular values smaller than `atol` are considered to be zero.
            (3) float 'rtol': relative tolerance. Singular values less than
                rtol*smax are considered to be zero, where smax is the largest
                singular value.

                If both `atol` and `rtol` are positive, the combined tolerance
                is the maximum of the two; tol = max(atol, rtol * smax)
                Singular values smaller than `tol` are considered to be zero.
    OUTPUT  (1) array 'B': if A is an array with shape (m, k), then B will be
                an array with shape (k, n), where n is the estimated dimension
                of the nullspace of A.  The columns of B are a basis for the
                nullspace; each element in np.dot(A, B) will be
                approximately zero.
    """
    # Expand A to a matrix
    A = np.atleast_2d(A)

    # Singular value decomposition; the full vh (k by k) is needed because
    # its trailing rows span the nullspace. Left singular vectors are unused.
    _, s, vh = al.svd(A)

    # Set tolerance; singular values are sorted descending, so s[0] is the
    # largest. Guard against an empty set of singular values.
    smax = s[0] if s.size > 0 else 0.0
    tol = max(atol, rtol * smax)

    # Number of singular values treated as non-zero (the numerical rank)
    nnz = int((s >= tol).sum())

    # Remaining rows of vh span the nullspace. Conjugate-transpose them so the
    # basis vectors become columns (conjugation matters only for complex A).
    return vh[nnz:].conj().T
132