Completed
Branch master (f50597)
by Wouter
created · 52s

TransferComponentClassifier   A

Complexity

Total Complexity 26

Size/Duplication

Total Lines 237
Duplicated Lines 21.52 %

Test Coverage

Coverage 0%

Importance

Changes 0
Metric   Value
c        0
b        0
f        0
dl       51     (duplicated lines)
loc      237    (lines of code)
ccs      0
cts      82
cp       0      (coverage)
rs       10
wmc      26     (weighted method complexity)

7 Methods

Rating   Name   Duplication   Size   Complexity  
A transfer_component_analysis() 0 55 4
B __init__() 51 51 4
A get_params() 0 3 1
B fit() 0 47 6
B kernel() 0 35 6
A is_trained() 0 3 1
B predict() 0 30 4


Duplicated Code

Duplicate code is one of the most pungent code smells. A common rule of thumb is to restructure code once it is duplicated in three or more places.

The most common solution is to extract the duplicated block into a single shared definition and call it from each site.
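For instance, the 51 duplicated lines flagged in __init__ below are the loss-to-classifier selection block. A minimal sketch of the usual fix, assuming sibling classifiers in this project repeat the same block (the helper name init_classifier is hypothetical):

    from sklearn.svm import LinearSVC
    from sklearn.linear_model import LogisticRegression, LinearRegression


    def init_classifier(loss):
        """Return an untrained scikit-learn model for a given loss name."""
        # Map each supported loss to its corresponding model class
        models = {'logistic': LogisticRegression,
                  'quadratic': LinearRegression,
                  'hinge': LinearSVC}
        if loss not in models:
            # Other loss functions are not implemented
            raise NotImplementedError('Loss function not implemented.')
        return models[loss]()

Each classifier's __init__ then reduces to self.clf = init_classifier(loss), and the duplication disappears.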

#!/usr/bin/env python
# -*- coding: utf-8 -*-

import numpy as np
import scipy.stats as st
from scipy.sparse.linalg import eigs
from scipy.spatial.distance import cdist
import sklearn as sk
from sklearn.svm import LinearSVC
from sklearn.linear_model import LogisticRegression, LinearRegression
from sklearn.model_selection import cross_val_predict
from os.path import basename

from .util import is_pos_def


class TransferComponentClassifier(object):
    """
    Class of classifiers based on Transfer Component Analysis.

    Methods contain component analysis and general utilities.
    """
    def __init__(self, loss='logistic', l2=1.0, mu=1.0, num_components=1,
                 kernel_type='rbf', bandwidth=1.0, order=2.0):
        """
        Select a particular type of transfer component classifier.

        INPUT   (1) str 'loss': loss function for the classifier, options:
                    'logistic', 'quadratic', 'hinge' (def: 'logistic')
                (2) float 'l2': l2-regularization parameter value (def: 1.0)
                (3) float 'mu': trade-off parameter (def: 1.0)
                (4) int 'num_components': number of transfer components to
                    maintain (def: 1)
                (5) str 'kernel_type': type of kernel to use, options:
                    'linear', 'polynomial', 'rbf', 'sigmoid' (def: 'rbf')
                (6) float 'bandwidth': kernel bandwidth for transfer component
                    analysis (def: 1.0)
                (7) float 'order': order of polynomial for kernel (def: 2.0)
        """
        self.loss = loss
        self.l2 = l2
        self.mu = mu
        self.num_components = num_components

        self.kernel_type = kernel_type
        self.bandwidth = bandwidth
        self.order = order

        # Initialize untrained classifier
        if self.loss == 'logistic':
            # Logistic regression model
            self.clf = LogisticRegression()
        elif self.loss == 'quadratic':
            # Least-squares model
            self.clf = LinearRegression()
        elif self.loss == 'hinge':
            # Linear support vector machine
            self.clf = LinearSVC()
        else:
            # Other loss functions are not implemented
            raise NotImplementedError('Loss function not implemented.')

        # Maintain source and target data for computing kernels
        self.XZ = None

        # Maintain transfer components
        self.C = None

        # Whether model has been trained; named 'trained' so the attribute
        # does not shadow the is_trained() method
        self.trained = False

        # Dimensionality of training data
        self.train_data_dim = None
    def kernel(self, X, Z, type='rbf', order=2, bandwidth=1.0):
        """
        Compute kernel for given data sets.

        INPUT   (1) array 'X': data set (N samples by D features)
                (2) array 'Z': data set (M samples by D features)
                (3) str 'type': type of kernel, options: 'linear',
                    'polynomial', 'rbf', 'sigmoid' (def: 'rbf')
                (4) float 'order': order of polynomial to use for the
                    polynomial kernel (def: 2.0)
                (5) float 'bandwidth': kernel bandwidth (def: 1.0)
        OUTPUT  (1) array: kernel matrix (N by M)
        """
        # Data shapes
        N, DX = X.shape
        M, DZ = Z.shape

        # Assert equivalent dimensionalities
        assert DX == DZ

        # Select type of kernel to compute
        if type == 'linear':
            # Linear kernel is the matrix of pairwise inner products
            return np.dot(X, Z.T)
        elif type == 'polynomial':
            # Polynomial kernel is an exponentiated inner product
            return (np.dot(X, Z.T) + 1)**order
        elif type == 'rbf':
            # Radial basis function kernel
            return np.exp(-cdist(X, Z) / (2.*bandwidth**2))
        elif type == 'sigmoid':
            # Sigmoidal kernel
            return 1./(1 + np.exp(np.dot(X, Z.T)))
        else:
            raise NotImplementedError('Kernel type not implemented.')
    def transfer_component_analysis(self, X, Z):
        """
        Transfer Component Analysis.

        INPUT   (1) array 'X': source data set (N samples by D features)
                (2) array 'Z': target data set (M samples by D features)
        OUTPUT  (1) array 'C': transfer components (N+M samples by
                    num_components)
                (2) array 'K': source and target data kernel matrix
                    (N+M by N+M)
        """
        # Data shapes
        N, DX = X.shape
        M, DZ = Z.shape

        # Assert equivalent dimensionalities
        assert DX == DZ

        # Compute kernel matrix of concatenated source and target data
        XZ = np.concatenate((X, Z), axis=0)
        K = self.kernel(XZ, XZ, type=self.kernel_type,
                        order=self.order, bandwidth=self.bandwidth)

        # Ensure positive-definiteness
        if not is_pos_def(K):
            print('Warning: kernel matrix is not positive-definite.')

            regct = -6
            while not is_pos_def(K):
                print('Adding regularization: ' + str(10.**regct))

                # Add regularization to the diagonal
                K += np.eye(N + M)*10.**regct

                # Increment regularization counter
                regct += 1

        # Normalization matrix
        L = np.vstack((np.hstack((np.ones((N, N))/N**2,
                                  -1*np.ones((N, M))/(N*M))),
                       np.hstack((-1*np.ones((M, N))/(N*M),
                                  np.ones((M, M))/M**2))))

        # Centering matrix
        H = np.eye(N + M) - np.ones((N + M, N + M)) / float(N + M)

        # Matrix in the Lagrangian objective: (I + mu*K*L*K)^{-1}*K*H*K;
        # its leading eigenvectors solve TCA's trace optimization
        # (Pan et al., 2011)
        J = np.dot(np.linalg.inv(np.eye(N + M) +
                   self.mu*np.dot(np.dot(K, L), K)),
                   np.dot(np.dot(K, H), K))

        # Eigenvector decomposition as solution to trace minimization
        _, C = eigs(J, k=self.num_components)

        # Discard imaginary parts (possible numerical artifact)
        return np.real(C), K
    def fit(self, X, y, Z):
        """
        Fit/train a classifier on data mapped onto transfer components.

        INPUT   (1) array 'X': source data (N samples by D features)
                (2) array 'y': source labels (N samples by 1)
                (3) array 'Z': target data (M samples by D features)
        OUTPUT  None
        """
        # Data shapes
        N, DX = X.shape
        M, DZ = Z.shape

        # Assert equivalent dimensionalities
        assert DX == DZ

        # Assert correct number of components for given dataset
        assert self.num_components <= N + M - 1

        # Maintain source and target data for later kernel computations
        self.XZ = np.concatenate((X, Z), axis=0)

        # Transfer component analysis
        self.C, K = self.transfer_component_analysis(X, Z)

        # Map source data onto transfer components
        X = np.dot(K[:N, :], self.C)

        # Train the classifier selected in __init__ on the mapped source
        # data; all three loss functions share the same fit call
        self.clf.fit(X, y)

        # Mark classifier as trained
        self.trained = True

        # Store training data dimensionality
        self.train_data_dim = DX
    def predict(self, Z_):
        """
        Make predictions on a new data set.

        INPUT   (1) array 'Z_': new data set (M samples by D features)
        OUTPUT  (1) array 'preds': label predictions (M samples by 1)
        """
        # Data shape
        M, D = Z_.shape

        # If classifier is trained, check for same dimensionality
        if self.trained:
            assert self.train_data_dim == D

        # Compute kernel between new data and stored training data
        K = self.kernel(Z_, self.XZ, type=self.kernel_type,
                        bandwidth=self.bandwidth, order=self.order)

        # Map new data onto transfer components
        Z_ = np.dot(K, self.C)

        # Call scikit-learn's predict function
        preds = self.clf.predict(Z_)

        # For quadratic loss, map regression outputs back to {0, 1} labels
        if self.loss == 'quadratic':
            preds = (np.sign(preds)+1)/2.

        # Return predictions array
        return preds
    def get_params(self):
        """Get classifier parameters."""
        return self.clf.get_params()

    def is_trained(self):
        """Check whether classifier is trained."""
        return self.trained
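For reference, a minimal usage sketch of the class above; the synthetic source/target data, the shift of the target set, and the parameter choices are made up for illustration:

    import numpy as np

    # Labeled source data and unlabeled, shifted target data (synthetic)
    X = np.random.randn(100, 5)
    y = np.random.randint(0, 2, size=100)
    Z = np.random.randn(50, 5) + 1.

    # Fit on source labels, adapting to the target domain via 2 components
    clf = TransferComponentClassifier(loss='logistic', num_components=2,
                                      kernel_type='rbf', bandwidth=1.0)
    clf.fit(X, y, Z)

    # Predict labels for the target samples
    preds = clf.predict(Z)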