Completed: Push to master (582254...17fb6a) by Wouter, created 03:58

transfer_component_analysis() (rated A)

Complexity: Conditions 4
Size: Total Lines 55
Duplication: Lines 0, Ratio 0 %
Code Coverage: Tests 12, CRAP Score 4.5923
Importance: Changes 0
Metric   Value
cc       4
c        0
b        0
f        0
dl       0
loc      55
ccs      12
cts      18
cp       0.6667
crap     4.5923
rs       9.078
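For reference, the CRAP ("Change Risk Anti-Patterns") score combines cyclomatic complexity (cc) and coverage (cp): crap = cc^2 * (1 - cp)^3 + cc. With cc = 4 and cp = 0.6667 this comes to roughly 4.59, in line with the reported 4.5923 up to rounding.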

How to fix: Long Method

Long Method

Small methods make your code easier to understand, particularly when combined with a good name. Moreover, when a method is small, finding a good name is usually much easier.

For example, if you find yourself adding comments to a method's body, that is usually a good sign to extract the commented part into a new method, using the comment as a starting point for naming it.

Commonly applied refactorings include Extract Method: splitting off part of a method's body into a new, well-named method, as sketched below.
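Applied to the flagged transfer_component_analysis() in the listing below, one candidate extraction is the kernel-regularization loop. A minimal sketch (the helper name _regularize_kernel is hypothetical, not part of the code under review):

def _regularize_kernel(self, K, N, M):
    """Add increasing diagonal regularization until K is positive-definite."""
    regct = -6
    while not is_pos_def(K):
        print('Adding regularization: ' + str(10**regct))
        K += np.eye(N + M) * 10.**regct
        regct += 1
    return K

The corresponding block in transfer_component_analysis() then collapses to a single, well-named call (K = self._regularize_kernel(K, N, M)), and the method shrinks accordingly.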

#!/usr/bin/env python
# -*- coding: utf-8 -*-

import numpy as np
import scipy.stats as st
from scipy.sparse.linalg import eigs
from scipy.spatial.distance import cdist
import sklearn as sk
from sklearn.svm import LinearSVC
from sklearn.linear_model import LogisticRegression, LinearRegression
from sklearn.model_selection import cross_val_predict
from os.path import basename

from .util import is_pos_def


class TransferComponentClassifier(object):
    """
    Class of classifiers based on Transfer Component Analysis.

    Methods contain component analysis and general utilities.
    """

    def __init__(self, loss='logistic', l2=1.0, mu=1.0, num_components=1,
                 kernel_type='rbf', bandwidth=1.0, order=2.0):
        """
        Select a particular type of transfer component classifier.

        INPUT   (1) str 'loss': loss function for weighted classifier, options:
                    'logistic', 'quadratic', 'hinge' (def: 'logistic')
                (2) float 'l2': l2-regularization parameter value (def: 1.0)
                (3) float 'mu': trade-off parameter (def: 1.0)
                (4) int 'num_components': number of transfer components to
                    maintain (def: 1)
                (5) str 'kernel_type': type of kernel to use, options: 'rbf'
                    (def: 'rbf')
                (6) float 'bandwidth': kernel bandwidth for transfer component
                    analysis (def: 1.0)
                (7) float 'order': order of polynomial for kernel (def: 2.0)
        """
        self.loss = loss
        self.l2 = l2
        self.mu = mu
        self.num_components = num_components

        self.kernel_type = kernel_type
        self.bandwidth = bandwidth
        self.order = order

        # Initialize untrained classifiers
        if self.loss == 'logistic':
            # Logistic regression model
            self.clf = LogisticRegression()
        elif self.loss == 'quadratic':
            # Least-squares model
            self.clf = LinearRegression()
        elif self.loss == 'hinge':
            # Linear support vector machine
            self.clf = LinearSVC()
        else:
            # Other loss functions are not implemented
            raise NotImplementedError

        # Maintain source and transfer data for computing kernels
        self.XZ = ''

        # Maintain transfer components
        self.C = ''

        # Whether model has been trained; kept private so it does not
        # shadow the is_trained() method below
        self._is_trained = False

        # Dimensionality of training data
        self.train_data_dim = ''

    def kernel(self, X, Z, type='rbf', order=2, bandwidth=1.0):
        """
        Compute kernel for given data set.

        INPUT   (1) array 'X': data set (N samples by D features)
                (2) array 'Z': data set (M samples by D features)
                (3) str 'type': type of kernel, options: 'linear',
                    'polynomial', 'rbf', 'sigmoid' (def: 'rbf')
                (4) float 'order': order of polynomial to use for the
                    polynomial kernel (def: 2)
                (5) float 'bandwidth': kernel bandwidth (def: 1.0)
        OUTPUT  (1) array: kernel matrix (N by M)
        """
        # Data shapes
        N, DX = X.shape
        M, DZ = Z.shape

        # Assert equivalent dimensionalities
        assert DX == DZ

        # Select type of kernel to compute
        if type == 'linear':
            # Linear kernel is data outer product
            return np.dot(X, Z.T)
        elif type == 'polynomial':
            # Polynomial kernel is an exponentiated data outer product
            return (np.dot(X, Z.T) + 1)**order
        elif type == 'rbf':
            # Radial basis function kernel
            return np.exp(-cdist(X, Z) / (2.*bandwidth**2))
        elif type == 'sigmoid':
            # Sigmoidal kernel
            return 1./(1 + np.exp(np.dot(X, Z.T)))
        else:
            raise NotImplementedError

    def transfer_component_analysis(self, X, Z):
        """
        Transfer Component Analysis.

        INPUT   (1) array 'X': source data set (N samples by D features)
                (2) array 'Z': target data set (M samples by D features)
        OUTPUT  (1) array 'C': transfer components (N+M samples
                    by num_components)
                (2) array 'K': source and target data kernel distances
        """
        # Data shapes
        N, DX = X.shape
        M, DZ = Z.shape

        # Assert equivalent dimensionalities
        assert DX == DZ

        # Compute kernel matrix
        XZ = np.concatenate((X, Z), axis=0)
        K = self.kernel(XZ, XZ, type=self.kernel_type,
                        bandwidth=self.bandwidth)

        # Ensure positive-definiteness
        if not is_pos_def(K):
            print('Warning: covariate matrices not PSD.')

            regct = -6
            while not is_pos_def(K):
                print('Adding regularization: ' + str(10**regct))

                # Add regularization
                K += np.eye(N + M)*10.**regct

                # Increment regularization counter
                regct += 1

        # Normalization matrix
        L = np.vstack((np.hstack((np.ones((N, N))/N**2,
                                  -1*np.ones((N, M))/(N*M))),
                       np.hstack((-1*np.ones((M, N))/(N*M),
                                  np.ones((M, M))/M**2))))
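        # (L is the maximum mean discrepancy coefficient matrix: entries are
        # 1/N^2 for source-source pairs, 1/M^2 for target-target pairs and
        # -1/(N*M) for cross-domain pairs, so trace(K*L) is the squared MMD
        # between the domain means in the kernel-induced feature space.)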

        # Centering matrix
        H = np.eye(N + M) - np.ones((N + M, N + M)) / float(N + M)

        # Matrix Lagrangian objective function: (I + mu*K*L*K)^{-1}*K*H*K
        J = np.dot(np.linalg.inv(np.eye(N + M) +
                   self.mu*np.dot(np.dot(K, L), K)),
                   np.dot(np.dot(K, H), K))
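        # (The leading eigenvectors of J trade off a small MMD between the
        # projected domains, via the K*L*K term, against preserved data
        # variance, via the centered K*H*K term, following the trace
        # optimization in Pan et al.'s Transfer Component Analysis.)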

        # Eigenvector decomposition as solution to trace minimization
        _, C = eigs(J, k=self.num_components)

        # Discard imaginary numbers (possible computation issue)
        return np.real(C), K

    def fit(self, X, y, Z):
        """
        Fit/train a classifier on data mapped onto transfer components.

        INPUT   (1) array 'X': source data (N samples by D features)
                (2) array 'y': source labels (N samples by 1)
                (3) array 'Z': target data (M samples by D features)
        OUTPUT  None
        """
        # Data shapes
        N, DX = X.shape
        M, DZ = Z.shape

        # Assert equivalent dimensionalities
        assert DX == DZ

        # Assert correct number of components for given dataset
        assert self.num_components <= N + M - 1

        # Maintain source and target data for later kernel computations
        self.XZ = np.concatenate((X, Z), axis=0)

        # Transfer component analysis
        self.C, K = self.transfer_component_analysis(X, Z)

        # Map source data onto transfer components
        X = np.dot(K[:N, :], self.C)

        # Train the chosen classifier on the mapped source data
        if self.loss == 'logistic':
            # Logistic regression model
            self.clf.fit(X, y)
        elif self.loss == 'quadratic':
            # Least-squares model
            self.clf.fit(X, y)
        elif self.loss == 'hinge':
            # Linear support vector machine
            self.clf.fit(X, y)
        else:
            # Other loss functions are not implemented
            raise NotImplementedError

        # Mark classifier as trained
        self._is_trained = True

        # Store training data dimensionality
        self.train_data_dim = DX

    def predict(self, Z_):
        """
        Make predictions on new dataset.

        INPUT   (1) array 'Z_': new data set (M samples by D features)
        OUTPUT  (1) array 'preds': label predictions (M samples by 1)
        """
        # Data shape
        M, D = Z_.shape

        # If classifier is trained, check for same dimensionality
        if self._is_trained:
            assert self.train_data_dim == D

        # Compute kernel for new data
        K = self.kernel(Z_, self.XZ, type=self.kernel_type,
                        bandwidth=self.bandwidth, order=self.order)

        # Map new data onto transfer components
        Z_ = np.dot(K, self.C)

        # Call scikit's predict function
        preds = self.clf.predict(Z_)

        # For quadratic loss, sign and shift regression outputs to {0, 1}
        if self.loss == 'quadratic':
            preds = (np.sign(preds)+1)/2.

        # Return predictions array
        return preds

    def get_params(self):
        """Get classifier parameters."""
        return self.clf.get_params()

    def is_trained(self):
        """Check whether classifier is trained."""
        return self._is_trained
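
For context, a minimal usage sketch of the class under review (the import path is hypothetical, and the synthetic data merely stand in for real source/target domains):

import numpy as np
from tca import TransferComponentClassifier  # hypothetical import path

# Synthetic domains: labels depend on the first feature; the target
# domain is shifted to simulate covariate shift.
rng = np.random.RandomState(1)
X = rng.randn(50, 3)                  # source samples
y = (X[:, 0] > 0).astype(int)         # source labels
Z = rng.randn(40, 3) + 0.5            # target samples

clf = TransferComponentClassifier(loss='logistic', num_components=2)
clf.fit(X, y, Z)                      # learn components, train classifier
preds = clf.predict(Z)                # predict labels for the target data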