Completed
Push — master ( f3d068...88fa67 )
by Wouter
03:44
created

SubspaceAlignedClassifier.fit()   B

Complexity

Conditions 5
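The condition count can be checked by hand with the usual McCabe rule: one path through the method, plus one for each decision point. `fit()` has four decision points: the dimensionality check, plus the `if`/`elif`/`elif` chain over `self.loss` (each `elif` parses as a nested `if`). The sketch below assumes that simplified rule, which may not match Scrutinizer's exact counting. It applies the rule to a hypothetical skeleton of `fit()` using Python's `ast` module. `FIT_SKELETON` and `cyclomatic_complexity` are illustrative names, not part of any tool.

```python
import ast
import textwrap

# Hypothetical control-flow skeleton of SubspaceAlignedClassifier.fit();
# branch bodies are reduced to `pass` because only the branching matters.
FIT_SKELETON = textwrap.dedent("""
    def fit(self, X, y, Z):
        if not DX == DZ:
            raise ValueError('Dimensionalities of X and Z should be equal.')
        if self.loss == 'logistic':
            pass
        elif self.loss == 'quadratic':
            pass
        elif self.loss == 'hinge':
            pass
        else:
            raise NotImplementedError
""")

# Node types counted as one extra decision point each (simplified McCabe rule).
DECISION_NODES = (ast.If, ast.For, ast.While, ast.IfExp, ast.ExceptHandler)


def cyclomatic_complexity(source):
    """Return 1 + the number of decision points found in `source`."""
    tree = ast.parse(source)
    decisions = sum(isinstance(node, DECISION_NODES) for node in ast.walk(tree))
    # Each extra operand of `and` / `or` adds one more branch.
    decisions += sum(len(node.values) - 1
                     for node in ast.walk(tree) if isinstance(node, ast.BoolOp))
    return 1 + decisions


print(cyclomatic_complexity(FIT_SKELETON))  # 5
```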

Size

Total Lines 58

Duplication

Lines 0
Ratio 0 %

Code Coverage

Tests 12
CRAP Score 5.9256

Importance

Changes 0
Metric Value
cc 5
c 0
b 0
f 0
dl 0
loc 58
ccs 12
cts 18
cp 0.6667
crap 5.9256
rs 8.392
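The coverage ratio and CRAP score in this table are consistent with the standard CRAP definition, CRAP = cc² × (1 − cp)³ + cc, where cp = ccs / cts. The following is a minimal sketch that assumes this definition (it is not taken from Scrutinizer's documentation) and uses a hypothetical helper, `crap_score`:

```python
def crap_score(complexity, coverage):
    """CRAP = complexity**2 * (1 - coverage)**3 + complexity, coverage in [0, 1]."""
    return complexity ** 2 * (1.0 - coverage) ** 3 + complexity


covered, coverable = 12, 18               # ccs and cts from the table above
coverage = round(covered / coverable, 4)  # cp
print(coverage)                           # 0.6667
print(round(crap_score(5, coverage), 4))  # 5.9256
```

With the unrounded 2/3, the score is 5.9259 instead, so the reported 5.9256 corresponds to the rounded cp of 0.6667. Higher coverage pushes the score down toward cc itself: full coverage gives exactly 5.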

How to fix: Long Method

Small methods make your code easier to understand, in particular if combined with a good name. Besides, if your method is small, finding a good name is usually much easier.

For example, if you find yourself adding comments to a method's body, this is usually a good sign to extract the commented part to a new method, and use the comment as a starting point when coming up with a good name for this new method.
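One possible Extract Method refactoring of the `SubspaceAlignedClassifier.fit()` listing in this report makes three changes:

- It moves the commented steps "Map source data onto source principal components" and "Align source data to target subspace" into a helper.
- It drops the second copy of the dimensionality check, because `subspace_alignment()` already performs it.
- It collapses the `if`/`elif` chain, whose branches all call `self.clf.fit(X, y)`, by validating `loss` once in `__init__`.

This is a sketch under those assumptions, not the project's code. `CLASSIFIERS`, `_check_dimensions` and `_align_source` are hypothetical names, and the class is trimmed to the parts `fit()` needs.

```python
import numpy as np
from sklearn.decomposition import PCA
from sklearn.linear_model import LinearRegression, LogisticRegression
from sklearn.svm import LinearSVC

# Hypothetical lookup table: loss name -> untrained scikit-learn estimator class.
CLASSIFIERS = {
    'logistic': LogisticRegression,
    'quadratic': LinearRegression,
    'hinge': LinearSVC,
}


class SubspaceAlignedClassifier(object):
    """Trimmed sketch of the class, showing fit() after Extract Method."""

    def __init__(self, loss='logistic', l2=1.0, num_components=1):
        # Unknown losses are rejected once here, so fit() needs no loss branches.
        if loss not in CLASSIFIERS:
            raise NotImplementedError('Loss function not implemented.')
        self.loss = loss
        self.l2 = l2
        self.num_components = num_components
        self.clf = CLASSIFIERS[loss]()
        self.is_trained = False
        self.train_data_dim = ''

    @staticmethod
    def _check_dimensions(X, Z):
        """Assert equivalent dimensionalities (hypothetical helper)."""
        if X.shape[1] != Z.shape[1]:
            raise ValueError('Dimensionalities of X and Z should be equal.')

    def subspace_alignment(self, X, Z, num_components=1):
        """Return the alignment matrix V and principal components CX, CZ."""
        self._check_dimensions(X, Z)
        CX = PCA(n_components=num_components, whiten=True).fit(X).components_.T
        CZ = PCA(n_components=num_components, whiten=True).fit(Z).components_.T
        return np.dot(CX.T, CZ), CX, CZ

    def _align_source(self, X, Z):
        """Project source data on its components, align to target (hypothetical)."""
        V, CX, CZ = self.subspace_alignment(X, Z,
                                            num_components=self.num_components)
        # Store target subspace for predict()
        self.target_subspace = CZ
        return np.dot(np.dot(X, CX), V)

    def fit(self, X, y, Z):
        """Fit the classifier on source data aligned to the target subspace."""
        self.clf.fit(self._align_source(X, Z), y)
        self.is_trained = True
        self.train_data_dim = X.shape[1]


# Usage with random data; y is built with hstack so that it holds 10 labels.
rng = np.random.RandomState(0)
X = rng.randn(10, 2)
y = np.hstack((-np.ones(5), np.ones(5)))
Z = rng.randn(10, 2)
clf = SubspaceAlignedClassifier()
clf.fit(X, y, Z)
print(clf.is_trained, clf.train_data_dim)  # True 2
```

With the branching gone, `fit()` has a single path and reads as a short sequence of named steps, which is what the advice above aims for.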

Commonly applied refactorings include:

- Extract Method

#!/usr/bin/env python
# -*- coding: utf-8 -*-

import numpy as np
import scipy.stats as st
from scipy.sparse.linalg import eigs
from scipy.spatial.distance import cdist
import sklearn as sk
from sklearn.decomposition import PCA
from sklearn.svm import LinearSVC
from sklearn.linear_model import LogisticRegression, LinearRegression
from sklearn.model_selection import cross_val_predict
from os.path import basename

from .util import is_pos_def


class SubspaceAlignedClassifier(object):
    """
    Class of classifiers based on Subspace Alignment.

    Methods contain the alignment itself, classifiers and general utilities.

    Examples
    --------
    >>> X = np.random.randn(10, 2)
    >>> y = np.hstack((-np.ones((5,)), np.ones((5,))))
    >>> Z = np.random.randn(10, 2)
    >>> clf = SubspaceAlignedClassifier()
    >>> clf.fit(X, y, Z)
    >>> preds = clf.predict(Z)
    """

    # Flagged by Scrutinizer: this code seems to be duplicated in the project.
    def __init__(self, loss='logistic', l2=1.0, num_components=1):
        """
        Select a particular type of subspace aligned classifier.

        Parameters
        ----------
        loss : str
            loss function for the classifier, options: 'logistic',
            'quadratic', 'hinge' (def: 'logistic')
        l2 : float
            l2-regularization parameter value (def: 1.0)
        num_components : int
            number of principal components to maintain (def: 1)

        Returns
        -------
        None

        """
        self.loss = loss
        self.l2 = l2
        self.num_components = num_components

        # Initialize untrained classifiers
        if self.loss == 'logistic':
            # Logistic regression model
            self.clf = LogisticRegression()
        elif self.loss == 'quadratic':
            # Least-squares model
            self.clf = LinearRegression()
        elif self.loss == 'hinge':
            # Linear support vector machine
            self.clf = LinearSVC()
        else:
            # Other loss functions are not implemented
            raise NotImplementedError('Loss function not implemented.')

        # Whether model has been trained
        self.is_trained = False

        # Dimensionality of training data
        self.train_data_dim = ''

    def subspace_alignment(self, X, Z, num_components=1):
        """
        Compute the principal subspaces and the alignment matrix.

        Parameters
        ----------
        X : array
            source data set (N samples by D features)
        Z : array
            target data set (M samples by D features)
        num_components : int
            number of components (def: 1)

        Returns
        -------
        V : array
            alignment matrix (num_components by num_components)
        CX : array
            source principal components (D features by num_components)
        CZ : array
            target principal components (D features by num_components)

        """
        # Data shapes
        N, DX = X.shape
        M, DZ = Z.shape

        # Assert equivalent dimensionalities
        if DX != DZ:
            raise ValueError('Dimensionalities of X and Z should be equal.')

        # Compute principal components
        CX = PCA(n_components=num_components, whiten=True).fit(X).components_.T
        CZ = PCA(n_components=num_components, whiten=True).fit(Z).components_.T

        # Alignment matrix between source and target subspaces
        V = np.dot(CX.T, CZ)

        # Return alignment matrix and principal components
        return V, CX, CZ

    def fit(self, X, y, Z):
        """
        Fit/train a classifier on source data aligned to the target subspace.

        Parameters
        ----------
        X : array
            source data (N samples by D features)
        y : array
            source labels (N samples by 1)
        Z : array
            target data (M samples by D features)

        Returns
        -------
        None

        """
        # Data shapes
        N, DX = X.shape
        M, DZ = Z.shape

        # Assert equivalent dimensionalities
        if DX != DZ:
            raise ValueError('Dimensionalities of X and Z should be equal.')

        # Subspace alignment
        V, CX, CZ = self.subspace_alignment(X, Z,
                                            num_components=self.num_components)

        # Store target subspace
        self.target_subspace = CZ

        # Map source data onto source principal components
        X = np.dot(X, CX)

        # Align source data to target subspace
        X = np.dot(X, V)

        # Train the classifier
        if self.loss == 'logistic':
            # Logistic regression model
            self.clf.fit(X, y)
        elif self.loss == 'quadratic':
            # Least-squares model
            self.clf.fit(X, y)
        elif self.loss == 'hinge':
            # Linear support vector machine
            self.clf.fit(X, y)
        else:
            # Other loss functions are not implemented
            raise NotImplementedError

        # Mark classifier as trained
        self.is_trained = True

        # Store training data dimensionality
        self.train_data_dim = DX

    # Flagged by Scrutinizer: this code seems to be duplicated in the project.
    def predict(self, Z, whiten=False):
        """
        Make predictions on a new data set.

        Parameters
        ----------
        Z : array
            new data set (M samples by D features)
        whiten : bool
            whether to standardize the new data with z-scores (def: False)

        Returns
        -------
        preds : array
            label predictions (M samples by 1)

        """
        # Data shape
        M, D = Z.shape

        # If classifier is trained, check for same dimensionality
        if self.is_trained:
            if self.train_data_dim != D:
                raise ValueError('Test data is of different dimensionality '
                                 'than training data.')

        # Check for need to whiten data beforehand
        if whiten:
            Z = st.zscore(Z)

        # Map new target data onto target subspace
        Z = np.dot(Z, self.target_subspace)

        # Call scikit-learn's predict function
        preds = self.clf.predict(Z)

        # For the quadratic loss, map regression outputs to labels at zero
        if self.loss == 'quadratic':
            preds = (np.sign(preds)+1)/2.

        # Return predictions array
        return preds

    def get_params(self):
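        """Get classifier parameters."""
        return self.clf.get_params()

    # Note: __init__ assigns the boolean attribute self.is_trained, which
    # shadows this method on instances, so clf.is_trained() raises TypeError.
    def is_trained(self):
        """Check whether classifier is trained."""
        return self.is_trained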
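The sketch below repeats the steps of `subspace_alignment()`, `fit()` and `predict()` outside the class, using NumPy and scikit-learn, to make the array shapes concrete. It is an illustration under assumed inputs (random Gaussian data, a target set shifted by +1, D = 5 features, d = 2 retained components, logistic loss), not a test of the module itself. The module needs its package to resolve the relative import of `is_pos_def`.

```python
import numpy as np
from sklearn.decomposition import PCA
from sklearn.linear_model import LogisticRegression

rng = np.random.RandomState(0)
N, M, D, d = 20, 15, 5, 2          # source/target sizes, features, components

X = rng.randn(N, D)                 # source data
y = np.hstack((-np.ones(N // 2), np.ones(N - N // 2)))  # N source labels
Z = rng.randn(M, D) + 1.0           # shifted target data

# Principal components as columns (D x d), as in subspace_alignment().
CX = PCA(n_components=d, whiten=True).fit(X).components_.T
CZ = PCA(n_components=d, whiten=True).fit(Z).components_.T

# Alignment matrix between the two subspaces (d x d).
V = CX.T @ CZ

# fit(): project source data onto its subspace, then align it to the target's.
X_aligned = X @ CX @ V              # N x d
clf = LogisticRegression().fit(X_aligned, y)

# predict(): project target data onto the target subspace only.
Z_projected = Z @ CZ                # M x d
preds = clf.predict(Z_projected)

print(V.shape, X_aligned.shape, Z_projected.shape, preds.shape)
# (2, 2) (20, 2) (15, 2) (15,)
```

V is d × d. Both the aligned source data X·CX·V and the projected target data Z·CZ have d columns, so a classifier trained on the former can predict on the latter.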