| 1 |  |  | #!/usr/bin/env python
 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 2 |  |  | # -*- coding: utf-8 -*-
 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 3 |  |  | 
 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 4 | 1 |  | import numpy as np
 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 5 | 1 |  | import scipy.stats as st
 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 6 | 1 |  | from scipy.sparse import linalg
 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 7 | 1 |  | from scipy.optimize import minimize
 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 8 | 1 |  | import sklearn as sk
 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 9 | 1 |  | from sklearn.svm import LinearSVC
 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 10 | 1 |  | from sklearn.linear_model import LogisticRegression, LinearRegression
 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 11 | 1 |  | from sklearn.model_selection import cross_val_predict
 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 12 | 1 |  | from os.path import basename
 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 13 |  |  | 
 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 14 | 1 |  | from .util import is_pos_def
 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 15 |  |  | 
 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 16 |  |  | 
 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 17 | 1 |  | class StructuralCorrespondenceClassifier(object):
 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 18 |  |  |     """
 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 19 |  |  |     Class of classifiers based on structural correspondence learning.
 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 20 |  |  | 
 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 21 |  |  |     Methods contain different importance-weight estimators and different loss
 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 22 |  |  |     functions.
 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 23 |  |  |     """
 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 24 |  |  | 
 | 
            
                                                                                                            
                            
            
                                                                    
                                                                                                        
            
            
                | 25 | 1 | View Code Duplication |     def __init__(self, loss='logistic', l2=1.0, num_pivots=1,
 | 
                            
                    |  |  |  | 
                                                                                        
                                                                                     | 
            
                                                                                                            
                            
            
                                    
            
            
                | 26 |  |  |                  num_components=1):
 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 27 |  |  |         """
 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 28 |  |  |         Select a particular type of importance-weighted classifier.
 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 29 |  |  | 
 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 30 |  |  |         INPUT   (1) str 'loss': loss function for weighted classifier, options:
 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 31 |  |  |                     'logistic', 'quadratic', 'hinge' (def: 'logistic')
 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 32 |  |  |                 (2) float 'l2': l2-regularization parameter value (def:0.01)
 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 33 |  |  |                 (3) int 'num_pivots': number of pivot features to use (def: 1)
 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 34 |  |  |                 (4) int 'num_components': number of components to use after
 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 35 |  |  |                     extracting pivot features (def: 1)
 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 36 |  |  |         """
 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 37 | 1 |  |         self.loss = loss
 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 38 | 1 |  |         self.l2 = l2
 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 39 | 1 |  |         self.num_pivots = num_pivots
 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 40 | 1 |  |         self.num_components = num_components
 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 41 |  |  | 
 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 42 |  |  |         # Initialize untrained classifiers based on choice of loss function
 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 43 | 1 |  |         if self.loss == 'logistic':
 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 44 |  |  |             # Logistic regression model
 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 45 | 1 |  |             self.clf = LogisticRegression()
 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 46 |  |  |         elif self.loss == 'quadratic':
 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 47 |  |  |             # Least-squares model
 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 48 |  |  |             self.clf = LinearRegression()
 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 49 |  |  |         elif self.loss == 'hinge':
 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 50 |  |  |             # Linear support vector machine
 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 51 |  |  |             self.clf = LinearSVC()
 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 52 |  |  |         else:
 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 53 |  |  |             # Other loss functions are not implemented
 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 54 |  |  |             raise NotImplementedError
 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 55 |  |  | 
 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 56 |  |  |         # Whether model has been trained
 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 57 | 1 |  |         self.is_trained = False
 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 58 |  |  | 
 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 59 |  |  |         # Maintain pivot component matrix
 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 60 | 1 |  |         self.C = 0
 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 61 |  |  | 
 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 62 |  |  |         # Dimensionality of training data
 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 63 | 1 |  |         self.train_data_dim = ''
 | 
            
                                                                                                            
                                                                
            
                                    
            
            
                | 64 |  |  | 
 | 
            
                                                                        
                            
            
                                    
            
            
                | 65 | 1 |  |     def augment_features(self, X, Z, l2=0.0):
 | 
            
                                                                        
                            
            
                                    
            
            
                | 66 |  |  |         """
 | 
            
                                                                        
                            
            
                                    
            
            
                | 67 |  |  |         Find a set of pivot features, train predictors and extract bases.
 | 
            
                                                                        
                            
            
                                    
            
            
                | 68 |  |  | 
 | 
            
                                                                        
                            
            
                                    
            
            
                | 69 |  |  |         INPUT   (1) array 'X': source data array (N samples by D features)
 | 
            
                                                                        
                            
            
                                    
            
            
                | 70 |  |  |                 (2) array 'Z': target data array (M samples by D features)
 | 
            
                                                                        
                            
            
                                    
            
            
                | 71 |  |  |         """
 | 
            
                                                                        
                            
            
                                    
            
            
                | 72 |  |  |         # Data shapes
 | 
            
                                                                        
                            
            
                                    
            
            
                | 73 |  |  |         N, DX = X.shape
 | 
            
                                                                        
                            
            
                                    
            
            
                | 74 |  |  |         M, DZ = Z.shape
 | 
            
                                                                        
                            
            
                                    
            
            
                | 75 |  |  | 
 | 
            
                                                                        
                            
            
                                    
            
            
                | 76 |  |  |         # Assert equivalent dimensionalities
 | 
            
                                                                        
                            
            
                                    
            
            
                | 77 |  |  |         assert DX == DZ
 | 
            
                                                                        
                            
            
                                    
            
            
                | 78 |  |  | 
 | 
            
                                                                        
                            
            
                                    
            
            
                | 79 |  |  |         # Concatenate source and target data
 | 
            
                                                                        
                            
            
                                    
            
            
                | 80 |  |  |         XZ = np.concatenate((X, Z), axis=0)
 | 
            
                                                                        
                            
            
                                    
            
            
                | 81 |  |  | 
 | 
            
                                                                        
                            
            
                                    
            
            
                | 82 |  |  |         # Sort indices based on frequency of features (assumes BoW encoding)
 | 
            
                                                                        
                            
            
                                    
            
            
                | 83 |  |  |         ix = np.argsort(np.sum(XZ, axis=0))
 | 
            
                                                                        
                            
            
                                    
            
            
                | 84 |  |  | 
 | 
            
                                                                        
                            
            
                                    
            
            
                | 85 |  |  |         # Keep most frequent features
 | 
            
                                                                        
                            
            
                                    
            
            
                | 86 |  |  |         ix = ix[::-1][:self.num_pivots]
 | 
            
                                                                        
                            
            
                                    
            
            
                | 87 |  |  | 
 | 
            
                                                                        
                            
            
                                    
            
            
                | 88 |  |  |         # Slice out pivot features and relabel them as present(=1)/absent(=0)
 | 
            
                                                                        
                            
            
                                    
            
            
                | 89 |  |  |         pivot = (XZ[:, ix] > 0).astype('float')
 | 
            
                                                                        
                            
            
                                    
            
            
                | 90 |  |  | 
 | 
            
                                                                        
                            
            
                                    
            
            
                | 91 |  |  |         # Solve prediction tasks with a Huber loss function
 | 
            
                                                                        
                            
            
                                    
            
            
                | 92 |  |  |         P = np.zeros((DX, self.num_pivots))
 | 
            
                                                                        
                            
            
                                    
            
            
                | 93 |  |  | 
 | 
            
                                                                        
                            
            
                                    
            
            
                | 94 |  |  |         # Loop over pivot features
 | 
            
                                                                        
                            
            
                                    
            
            
                | 95 |  |  |         for l in range(self.num_pivots):
 | 
            
                                                                        
                            
            
                                    
            
            
                | 96 |  |  | 
 | 
            
                                                                        
                            
            
                                    
            
            
                | 97 |  |  |             # Setup loss function for single pivot
 | 
            
                                                                        
                            
            
                                    
            
            
                | 98 |  |  |             def L(theta): return self.Huber_loss(theta, XZ, pivot[:, l])
 | 
            
                                                                        
                            
            
                                    
            
            
                | 99 |  |  | 
 | 
            
                                                                        
                            
            
                                    
            
            
                | 100 |  |  |             # Setup gradient function for single pivot
 | 
            
                                                                        
                            
            
                                    
            
            
                | 101 |  |  |             def J(theta): return self.Huber_grad(theta, XZ, pivot[:, l])
 | 
            
                                                                        
                            
            
                                    
            
            
                | 102 |  |  | 
 | 
            
                                                                        
                            
            
                                    
            
            
                | 103 |  |  |             # Make pivot predictor with a Huber loss function
 | 
            
                                                                        
                            
            
                                    
            
            
                | 104 |  |  |             results = minimize(L, np.random.randn(DX, 1), jac=J, method='BFGS',
 | 
            
                                                                        
                            
            
                                    
            
            
                | 105 |  |  |                                options={'gtol': 1e-6, 'disp': True})
 | 
            
                                                                        
                            
            
                                    
            
            
                | 106 |  |  | 
 | 
            
                                                                        
                            
            
                                    
            
            
                | 107 |  |  |             # Store optimal parameters
 | 
            
                                                                        
                            
            
                                    
            
            
                | 108 |  |  |             P[:, l] = results.x
 | 
            
                                                                        
                            
            
                                    
            
            
                | 109 |  |  | 
 | 
            
                                                                        
                            
            
                                    
            
            
                | 110 |  |  |         # Compute covariance matrix of predictors
 | 
            
                                                                        
                            
            
                                    
            
            
                | 111 |  |  |         SP = np.cov(P)
 | 
            
                                                                        
                            
            
                                    
            
            
                | 112 |  |  | 
 | 
            
                                                                        
                            
            
                                    
            
            
                | 113 |  |  |         # Add regularization to ensure positive-definiteness
 | 
            
                                                                        
                            
            
                                    
            
            
                | 114 |  |  |         SP += l2*np.eye(self.num_pivots)
 | 
            
                                                                        
                            
            
                                    
            
            
                | 115 |  |  | 
 | 
            
                                                                        
                            
            
                                    
            
            
                | 116 |  |  |         # Eigenvalue decomposition of pivot predictor matrix
 | 
            
                                                                        
                            
            
                                    
            
            
                | 117 |  |  |         V, C = np.linalg.eig(SP)
 | 
            
                                                                        
                            
            
                                    
            
            
                | 118 |  |  | 
 | 
            
                                                                        
                            
            
                                    
            
            
                | 119 |  |  |         # Reduce number of components
 | 
            
                                                                        
                            
            
                                    
            
            
                | 120 |  |  |         C = C[:, :self.num_components]
 | 
            
                                                                        
                            
            
                                    
            
            
                | 121 |  |  | 
 | 
            
                                                                        
                            
            
                                    
            
            
                | 122 |  |  |         # Augment features
 | 
            
                                                                        
                            
            
                                    
            
            
                | 123 |  |  |         Xa = np.concatenate((np.dot(X, C), X), axis=1)
 | 
            
                                                                        
                            
            
                                    
            
            
                | 124 |  |  |         Za = np.concatenate((np.dot(Z, C), Z), axis=1)
 | 
            
                                                                        
                            
            
                                    
            
            
                | 125 |  |  | 
 | 
            
                                                                        
                            
            
                                    
            
            
                | 126 |  |  |         return Xa, Za, C
 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 127 |  |  | 
 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 128 | 1 |  |     def Huber_loss(self, theta, X, y, l2=0.0):
 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 129 |  |  |         """
 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 130 |  |  |         Huber loss function.
 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 131 |  |  | 
 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 132 |  |  |         Reference: Ando & Zhang (2005a). A framework for learning predictive
 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 133 |  |  |         structures from multiple tasks and unlabeled data. JMLR.
 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 134 |  |  | 
 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 135 |  |  |         INPUT   (1) array 'theta': classifier parameters (D features by 1)
 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 136 |  |  |                 (2) array 'X': data (N samples by D features)
 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 137 |  |  |                 (3) array 'y': label vector (N samples by 1)
 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 138 |  |  |                 (4) float 'l2': l2-regularization parameter (def= 0.0)
 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 139 |  |  |         OUTPUT  (1) Loss/objective function value
 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 140 |  |  |                 (2) Gradient with respect to classifier parameters
 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 141 |  |  |         """
 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 142 |  |  |         # Precompute terms
 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 143 |  |  |         Xy = (X.T*y.T).T
 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 144 |  |  |         Xyt = np.dot(Xy, theta)
 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 145 |  |  | 
 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 146 |  |  |         # Indices of discontinuity
 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 147 |  |  |         ix = (Xyt >= -1)
 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 148 |  |  | 
 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 149 |  |  |         # Loss function
 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 150 |  |  |         return np.sum(np.clip(1 - Xyt[ix], 0, None)**2, axis=0) \
 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 151 |  |  |             + np.sum(-4*Xyt[~ix], axis=0) + l2*np.sum(theta**2, axis=0)
 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 152 |  |  | 
 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 153 | 1 |  |     def Huber_grad(self, theta, X, y, l2=0.0):
 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 154 |  |  |         """
 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 155 |  |  |         Huber gradient computation.
 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 156 |  |  | 
 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 157 |  |  |         Reference: Ando & Zhang (2005a). A framework for learning predictive
 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 158 |  |  |         structures from multiple tasks and unlabeled data. JMLR.
 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 159 |  |  | 
 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 160 |  |  |         INPUT   (1) array 'theta': classifier parameters (D features by 1)
 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 161 |  |  |                 (2) array 'X': data (N samples by D features)
 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 162 |  |  |                 (3) array 'y': label vector (N samples by 1)
 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 163 |  |  |                 (4) float 'l2': l2-regularization parameter (def= 0.0)
 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 164 |  |  |         OUTPUT  (1) Loss/objective function value
 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 165 |  |  |                 (2) Gradient with respect to classifier parameters
 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 166 |  |  |         """
 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 167 |  |  |         # Precompute terms
 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 168 |  |  |         Xy = (X.T*y.T).T
 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 169 |  |  |         Xyt = np.dot(Xy, theta)
 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 170 |  |  | 
 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 171 |  |  |         # Indices of discontinuity
 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 172 |  |  |         ix = (Xyt >= -1)
 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 173 |  |  | 
 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 174 |  |  |         # Gradient
 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 175 |  |  |         return np.sum(2*np.clip(1-Xyt[ix], 0, None).T * -Xy[ix, :].T,
 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 176 |  |  |                       axis=1).T + np.sum(-4*Xy[~ix, :], axis=0) + 2*l2*theta
 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 177 |  |  | 
 | 
            
                                                                                                            
                            
            
                                                                    
                                                                                                        
            
            
                | 178 | 1 | View Code Duplication |     def fit(self, X, y, Z):
 | 
                            
                    |  |  |  | 
                                                                                        
                                                                                     | 
            
                                                                                                            
                            
            
                                    
            
            
                | 179 |  |  |         """
 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 180 |  |  |         Fit/train an structural correpondence classifier.
 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 181 |  |  | 
 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 182 |  |  |         INPUT   (1) array 'X': source data (N samples by D features)
 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 183 |  |  |                 (2) array 'y': source labels (N samples by 1)
 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 184 |  |  |                 (3) array 'Z': target data (M samples by D features)
 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 185 |  |  |         OUTPUT  None
 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 186 |  |  |         """
 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 187 |  |  |         # Data shapes
 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 188 |  |  |         N, DX = X.shape
 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 189 |  |  |         M, DZ = Z.shape
 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 190 |  |  | 
 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 191 |  |  |         # Assert equivalent dimensionalities
 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 192 |  |  |         assert DX == DZ
 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 193 |  |  | 
 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 194 |  |  |         # Augment features
 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 195 |  |  |         X, _, self.C = self.augment_features(X, Z, l2=self.l2)
 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 196 |  |  | 
 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 197 |  |  |         # Train a classifier
 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 198 |  |  |         if self.loss == 'logistic':
 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 199 |  |  |             # Logistic regression model
 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 200 |  |  |             self.clf.fit(X, y)
 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 201 |  |  |         elif self.loss == 'quadratic':
 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 202 |  |  |             # Least-squares model
 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 203 |  |  |             self.clf.fit(X, y)
 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 204 |  |  |         elif self.loss == 'hinge':
 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 205 |  |  |             # Linear support vector machine
 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 206 |  |  |             self.clf.fit(X, y)
 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 207 |  |  |         else:
 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 208 |  |  |             # Other loss functions are not implemented
 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 209 |  |  |             raise NotImplementedError
 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 210 |  |  | 
 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 211 |  |  |         # Mark classifier as trained
 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 212 |  |  |         self.is_trained = True
 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 213 |  |  | 
 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 214 |  |  |         # Store training data dimensionality
 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 215 |  |  |         self.train_data_dim = DX + self.num_components
 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 216 |  |  | 
 | 
            
                                                                                                            
                            
            
                                    
            
            
                def predict(self, Z_):
                    """
                    Make predictions on new dataset.

                    INPUT   (1) array 'Z_': new data set (M samples by D features)
                    OUTPUT  (1) array 'preds': label predictions (M samples by 1)

                    When the stored training dimensionality differs from D, the data
                    is augmented as ``[Z_ @ self.C, Z_]`` before being handed to
                    ``self.clf.predict``. For ``self.loss == 'quadratic'`` the raw
                    outputs are mapped through ``(sign(p) + 1) / 2``.

                    Raises ValueError if Z_ is not two-dimensional, or if the
                    classifier is trained and D matches neither the training
                    dimensionality nor that dimensionality minus
                    ``self.num_components``.
                    """
                    # Unpacking enforces a 2-D input; only the feature count is used.
                    _, D = Z_.shape

                    # A trained classifier accepts either the full training width or
                    # raw features still missing the num_components projected ones.
                    # Explicit raise (not assert) so the check survives `python -O`.
                    if self.is_trained:
                        if self.train_data_dim not in (D, D + self.num_components):
                            raise ValueError(
                                "Z_ has {} features; expected {} (or {} before "
                                "augmentation).".format(
                                    D, self.train_data_dim,
                                    self.train_data_dim - self.num_components))

                    # Augment with the projection Z_ @ C when the classifier was
                    # trained on augmented features.
                    if self.train_data_dim != D:
                        Z_ = np.concatenate((np.dot(Z_, self.C), Z_), axis=1)

                    # Delegate to the wrapped scikit-learn classifier.
                    preds = self.clf.predict(Z_)

                    # For quadratic loss, map the sign of the raw output:
                    # -1 -> 0, 0 -> 0.5, +1 -> 1.
                    if self.loss == 'quadratic':
                        preds = (np.sign(preds) + 1) / 2.

                    return preds
            
                                                                                                            
                            
            
                                    
            
            
                | 245 |  |  | 
 | 
            
                                                                                                            
                            
            
                                    
            
            
                def get_params(self):
                    """Return the parameters reported by the wrapped classifier ``self.clf``."""
                    wrapped_clf = self.clf
                    params = wrapped_clf.get_params()
                    return params
            
                                                                                                            
                            
            
                                    
            
            
                | 249 |  |  | 
 | 
            
                                                                                                            
                            
            
                                    
            
            
                def is_trained(self):
                    """
                    Report the object's ``is_trained`` value.

                    NOTE(review): predict() reads ``self.is_trained`` as a flag. If
                    instances assign an attribute with that name, it shadows this
                    method, so ``obj.is_trained()`` would try to call that attribute
                    (a TypeError for a bool). The method then stays reachable only as
                    ``type(obj).is_trained(obj)``. If no such attribute is set, this
                    returns the bound method itself. Consider renaming the attribute
                    or the method -- confirm against __init__/fit.
                    """
                    state = self.is_trained
                    return state
            
                                                        
            
                                    
            
            
                | 253 |  |  |  |