SDA.fit()   F

Complexity: Conditions 26
Size: Total Lines 148
Duplication: Lines 44 (Ratio 29.73 %)
Importance: Changes 2, Bugs 0, Features 0
Metric  Value
cc      26    (cyclomatic complexity / conditions)
c       2     (changes)
b       0     (bugs)
f       0     (features)
dl      44    (duplicated lines)
loc     148   (total lines)
rs      2

How to fix

Long Method

Small methods make your code easier to understand, in particular if combined with a good name. Besides, if your method is small, finding a good name is usually much easier.

For example, if you find yourself adding comments to a method's body, this is usually a good sign to extract the commented part to a new method, and use the comment as a starting point when coming up with a good name for this new method.

Commonly applied refactorings include Extract Method: move a coherent, nameable chunk of the method body into its own well-named method, as in the sketch below.
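
For instance, the validation hold-out logic that fit() (listed below) repeats in both its pre-training and fine-tuning stages could be extracted into a single helper. A minimal sketch; the method name _split_validation_set and the holdout_fraction parameter are hypothetical, not part of the original code:

    def _split_validation_set(self, x, y, holdout_fraction=0.2):
        # Hypothetical helper extracted from fit(): split off the last
        # holdout_fraction of samples as a validation set and return
        # (train_x, train_y, valid_x, valid_y).
        num_samples = x.shape[0]
        split = num_samples - int(holdout_fraction * num_samples)
        return x[:split, :], y[:split, :], x[split:, :], y[split:, :]

Both stages of fit() could then call this helper instead of repeating the slicing, and the comment that used to mark that block becomes the method name.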

Complexity

Complex classes like SDA.fit() often do a lot of different things. To break such a class down, we need to identify a cohesive component within that class. A common approach to finding such a component is to look for fields/methods that share the same prefixes or suffixes.

Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a sub-class, Extract Subclass is also a candidate, and is often faster.
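
In SDA.fit(), the train_writer/test_writer/valid_writer fields share a suffix and are created and closed together in both training stages, which makes them a natural candidate. A minimal sketch of the extraction, assuming the tensorflow import from the module below; the class name SummaryWriters is hypothetical:

    class SummaryWriters:
        # Hypothetical extracted component: owns the train/test/valid
        # FileWriter trio that fit() currently re-creates for each stage.
        def __init__(self, base_dir):
            self.train = tf.summary.FileWriter(base_dir + '/train')
            self.test = tf.summary.FileWriter(base_dir + '/test')
            self.valid = tf.summary.FileWriter(base_dir + '/valid')

        def close_all(self):
            for writer in (self.train, self.test, self.valid):
                writer.close()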

import logging
import numpy as np
import tensorflow as tf
from .layers import AutoencoderLayer, HiddenLayer, SoftmaxLayer
from .injectors import BatchInjector
from .criterion import MonitorBased, ConstIterations

logger = logging.getLogger(__name__)


class SDA:
    """Stacked Auto-encoder

    Args:
        num_features (:obj:`int`): Number of features.
        num_classes (:obj:`int`): Number of classes.
        layers (:obj:`list` of :obj:`int`): Sizes of the hidden auto-encoder layers.
        encode_optimizer: Optimizer used for the auto-encoding (pre-training) process.
        tuning_optimizer: Optimizer used for fine-tuning.

    Attributes:
        num_features (:obj:`int`): Number of features.
        num_classes (:obj:`int`): Number of classes.
        x (:obj:`tensorflow.placeholder`): Input placeholder.
        y_ (:obj:`tensorflow.placeholder`): Output placeholder.
        inner_layers (:obj:`list`): List of auto-encoder hidden layers.

    """
    def __init__(self, num_features, num_classes, layers, encode_optimizer=None, tuning_optimizer=None):
        self.num_features = num_features
        self.num_classes = num_classes
        with tf.name_scope('input'):
            self.x = tf.placeholder(tf.float32, shape=[None, num_features], name='input_x')
            self.y_ = tf.placeholder(tf.float32, shape=[None, num_classes], name='input_y')
        self.inner_layers = []
        self.summaries = []
        self.encode_opts = []
        if encode_optimizer is None:
            self.encode_optimizer = tf.train.AdamOptimizer()
        else:
            self.encode_optimizer = encode_optimizer
        if tuning_optimizer is None:
            self.tuning_optimizer = tf.train.AdamOptimizer()
        else:
            self.tuning_optimizer = tuning_optimizer
        # Create layers: the first auto-encoder reads the input placeholder,
        # each subsequent one reads the previous layer's output.
        for i in range(len(layers)):
            if i == 0:
                self.inner_layers.append(
                    AutoencoderLayer(num_features, layers[i], x=self.x, name='Hidden%d' % i)
                )
            else:
                self.inner_layers.append(
                    AutoencoderLayer(layers[i - 1], layers[i], x=self.inner_layers[i - 1].y, name='Hidden%d' % i)
                )
            self.summaries += self.inner_layers[i].summaries
            # Each layer gets its own pre-training op, restricted to that layer's variables.
            self.encode_opts.append(
                self.encode_optimizer.minimize(self.inner_layers[i].encode_loss,
                                               var_list=self.inner_layers[i].variables)
            )
        if num_classes == 1:
            # Binary classification: sigmoid output, thresholded at 0.5.
            self.output_layer = HiddenLayer(layers[-1], num_classes, x=self.inner_layers[-1].y,
                                            name='Output', activation_fn=tf.sigmoid)
            # Predicted probability
            self.y = self.output_layer.y
            self.y_class = tf.cast(tf.greater_equal(self.y, 0.5), tf.float32)
            # Loss (keyword arguments are required by tf.nn.*_cross_entropy_with_logits)
            self.loss = tf.reduce_mean(
                tf.nn.sigmoid_cross_entropy_with_logits(logits=self.output_layer.logits, labels=self.y_,
                                                        name='SigmoidCrossEntropyLoss')
            )
            self.correct_prediction = tf.equal(self.y_class, self.y_)
            self.accuracy = tf.reduce_mean(tf.cast(self.correct_prediction, tf.float32))
        else:
            # Multi-class classification: softmax output, argmax decision.
            self.output_layer = SoftmaxLayer(layers[-1], num_classes, x=self.inner_layers[-1].y,
                                             name='OutputLayer')
            # Predicted probability
            self.y = self.output_layer.y
            self.y_class = tf.argmax(self.y, 1)
            # Loss
            self.loss = tf.reduce_mean(
                tf.nn.softmax_cross_entropy_with_logits(logits=self.output_layer.logits, labels=self.y_,
                                                        name='SoftmaxCrossEntropyLoss')
            )
            self.correct_prediction = tf.equal(self.y_class, tf.argmax(self.y_, 1))
            self.accuracy = tf.reduce_mean(tf.cast(self.correct_prediction, tf.float32))
        self.summaries.append(tf.summary.scalar('cross_entropy', self.loss))
        self.summaries.append(tf.summary.scalar('accuracy', self.accuracy))
        self.summaries += self.output_layer.summaries
        with tf.name_scope('train'):
            self.fine_tuning = self.tuning_optimizer.minimize(self.loss)
        self.merged = tf.summary.merge(self.summaries)
        self.sess = None

    def fit(self, x, y, batch_size=100,
            pretrain_iter_num=100, pretrain_criterion='const_iterations',
            tuning_iter_num=100, tuning_criterion='const_iterations',
            summaries_dir=None, test_x=None, test_y=None, summary_interval=10,
            session=None):
        """Fit the model to the dataset

        Args:
            x (:obj:`numpy.ndarray`): Input features of shape (num_samples, num_features).
            y (:obj:`numpy.ndarray`): Corresponding labels of shape (num_samples) for binary classification,
                or (num_samples, num_classes) for multi-class classification.
            batch_size (:obj:`int`): Batch size used in gradient descent.
            pretrain_iter_num (:obj:`int`): Number of const iterations, or search depth for the monitor-based
                stopping criterion, in the pre-training stage.
            pretrain_criterion (:obj:`str`): Stopping criterion in the pre-training stage ('const_iterations'
                or 'monitor_based').
            tuning_iter_num (:obj:`int`): Number of const iterations, or search depth for the monitor-based
                stopping criterion, in the fine-tuning stage.
            tuning_criterion (:obj:`str`): Stopping criterion in the fine-tuning stage ('const_iterations'
                or 'monitor_based').
            summaries_dir (:obj:`str`): Path of the directory to store summaries and saved values. Required
                when a 'monitor_based' criterion is used, since checkpoints are saved beneath it.
            summary_interval (:obj:`int`): The step interval to export variable summaries.
            test_x (:obj:`numpy.ndarray`): Test feature array used for monitoring training progress.
            test_y (:obj:`numpy.ndarray`): Test label array used for monitoring training progress.
            session (:obj:`tensorflow.Session`): Session to run training functions.
        """
        if session is None:
            if self.sess is None:
                session = tf.Session()
                self.sess = session
            else:
                session = self.sess
        session.run(tf.global_variables_initializer())
        # Pre-training stage: train the auto-encoder layers one by one.
        for j in range(len(self.inner_layers)):
            current_layer = self.inner_layers[j]
            if summaries_dir is not None:
                layer_summaries_dir = '%s/pretrain_layer%d' % (summaries_dir, j)
                train_writer = tf.summary.FileWriter(layer_summaries_dir + '/train')
                test_writer = tf.summary.FileWriter(layer_summaries_dir + '/test')
                valid_writer = tf.summary.FileWriter(layer_summaries_dir + '/valid')
            # Set up the stopping criterion. The monitor-based criterion saves
            # checkpoints under layer_summaries_dir, so it needs summaries_dir set.
            if pretrain_criterion == 'const_iterations':
                _pretrain_criterion = ConstIterations(num_iters=pretrain_iter_num)
                train_x = x
                train_y = y
            elif pretrain_criterion == 'monitor_based':
                # Hold out the last 20% of the samples as a validation set.
                num_samples = x.shape[0]
                valid_set_len = num_samples // 5
                valid_x = x[num_samples - valid_set_len:num_samples, :]
                valid_y = y[num_samples - valid_set_len:num_samples, :]
                train_x = x[0:num_samples - valid_set_len, :]
                train_y = y[0:num_samples - valid_set_len, :]
                _pretrain_criterion = MonitorBased(n_steps=pretrain_iter_num,
                                                   monitor_fn=self.get_encode_loss,
                                                   monitor_fn_args=(current_layer, valid_x, valid_y),
                                                   save_fn=tf.train.Saver().save,
                                                   save_fn_args=(session, layer_summaries_dir + '/best.ckpt'))
            else:
                logger.error('Wrong criterion %s specified.' % pretrain_criterion)
                return
            injector = BatchInjector(data_x=train_x, data_y=train_y, batch_size=batch_size)
            i = 0
            while _pretrain_criterion.continue_learning():
                batch_x, batch_y = injector.next_batch()
                if summaries_dir is not None and (i % summary_interval == 0):
                    summary, loss = session.run(
                        [current_layer.merged, current_layer.encode_loss],
                        feed_dict={self.x: x, self.y_: y}
                    )
                    train_writer.add_summary(summary, i)
                    logger.info('Pre-training Layer %d, Step %d, training loss %g' % (j, i, loss))
                    if test_x is not None and test_y is not None:
                        summary, loss = session.run(
                            [current_layer.merged, current_layer.encode_loss],
                            feed_dict={self.x: test_x, self.y_: test_y}
                        )
                        test_writer.add_summary(summary, i)
                        logger.info('Pre-training Layer %d, Step %d, test loss %g' % (j, i, loss))
                    if pretrain_criterion == 'monitor_based':
                        summary, loss = session.run(
                            [current_layer.merged, current_layer.encode_loss],
                            feed_dict={self.x: valid_x, self.y_: valid_y}
                        )
                        valid_writer.add_summary(summary, i)
                        logger.info('Pre-training Layer %d, Step %d, valid loss %g' % (j, i, loss))
                _ = session.run(self.encode_opts[j], feed_dict={self.x: batch_x, self.y_: batch_y})
                i += 1
            if pretrain_criterion == 'monitor_based':
                # Roll back to the best checkpoint found by the monitor.
                tf.train.Saver().restore(session, layer_summaries_dir + '/best.ckpt')
            if summaries_dir is not None:
                train_writer.close()
                test_writer.close()
                valid_writer.close()
        # All internal layers pre-trained; start fine-tuning the whole network.
        if summaries_dir is not None:
            tuning_summaries_dir = '%s/fine_tuning' % summaries_dir
            train_writer = tf.summary.FileWriter(tuning_summaries_dir + '/train')
            test_writer = tf.summary.FileWriter(tuning_summaries_dir + '/test')
            valid_writer = tf.summary.FileWriter(tuning_summaries_dir + '/valid')
        # Set up the stopping criterion for fine-tuning.
        if tuning_criterion == 'const_iterations':
            _tuning_criterion = ConstIterations(num_iters=tuning_iter_num)
            train_x = x
            train_y = y
        elif tuning_criterion == 'monitor_based':
            # Hold out the last 20% of the samples as a validation set.
            num_samples = x.shape[0]
            valid_set_len = num_samples // 5
            valid_x = x[num_samples - valid_set_len:num_samples, :]
            valid_y = y[num_samples - valid_set_len:num_samples, :]
            train_x = x[0:num_samples - valid_set_len, :]
            train_y = y[0:num_samples - valid_set_len, :]
            _tuning_criterion = MonitorBased(n_steps=tuning_iter_num,
                                             monitor_fn=self.predict_accuracy,
                                             monitor_fn_args=(valid_x, valid_y),
                                             save_fn=tf.train.Saver().save,
                                             save_fn_args=(session, tuning_summaries_dir + '/best.ckpt'))
        else:
            logger.error('Wrong criterion %s specified.' % tuning_criterion)
            return
        injector = BatchInjector(data_x=train_x, data_y=train_y, batch_size=batch_size)
        i = 0
        while _tuning_criterion.continue_learning():
            batch_x, batch_y = injector.next_batch()
            if summaries_dir is not None and (i % summary_interval == 0):
                summary, loss, accuracy = session.run([self.merged, self.loss, self.accuracy],
                                                      feed_dict={self.x: train_x, self.y_: train_y})
                train_writer.add_summary(summary, i)
                logger.info('Fine-Tuning: Step %d, training accuracy %g, loss %g' % (i, accuracy, loss))
                if (test_x is not None) and (test_y is not None):
                    merged, accuracy = session.run([self.merged, self.accuracy],
                                                   feed_dict={self.x: test_x, self.y_: test_y})
                    test_writer.add_summary(merged, i)
                    logger.info('Fine-Tuning: Step %d, test accuracy %g' % (i, accuracy))
                if tuning_criterion == 'monitor_based':
                    merged, accuracy = session.run([self.merged, self.accuracy],
                                                   feed_dict={self.x: valid_x, self.y_: valid_y})
                    valid_writer.add_summary(merged, i)
                    logger.info('Fine-Tuning: Step %d, valid accuracy %g' % (i, accuracy))
            _ = session.run(self.fine_tuning, feed_dict={self.x: batch_x, self.y_: batch_y})
            i += 1
        if tuning_criterion == 'monitor_based':
            # Roll back to the best checkpoint found by the monitor.
            tf.train.Saver().restore(session, tuning_summaries_dir + '/best.ckpt')
        if summaries_dir is not None:
            train_writer.close()
            test_writer.close()
            valid_writer.close()

    def get_encode_loss(self, layer, x, y, session=None):
        """Get the encoder loss of the specified layer."""
        if session is None:
            if self.sess is None:
                session = tf.Session()
                self.sess = session
            else:
                session = self.sess
        return session.run(layer.encode_loss, feed_dict={self.x: x, self.y_: y})

    def predict_accuracy(self, x, y, session=None):
        """Get accuracy given a feature array and corresponding labels."""
        if session is None:
            if self.sess is None:
                session = tf.Session()
                self.sess = session
            else:
                session = self.sess
        return session.run(self.accuracy, feed_dict={self.x: x, self.y_: y})

    def predict_proba(self, x, session=None):
        """Predict class probabilities (softmax, or sigmoid for binary)."""
        if session is None:
            if self.sess is None:
                session = tf.Session()
                self.sess = session
            else:
                session = self.sess
        return session.run(self.y, feed_dict={self.x: x})

    def predict(self, x, session=None):
        """Predict class labels for a feature array."""
        if session is None:
            if self.sess is None:
                session = tf.Session()
                self.sess = session
            else:
                session = self.sess
        return session.run(self.y_class, feed_dict={self.x: x})
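
For reference, a minimal usage sketch. The shapes, layer sizes, and iteration counts are illustrative assumptions, and the random data is a placeholder; only the SDA API above is taken from the source:

import numpy as np

# Illustrative: 64 input features, 3 classes, two auto-encoder layers.
model = SDA(num_features=64, num_classes=3, layers=[32, 16])

x = np.random.rand(500, 64).astype(np.float32)
labels = np.random.randint(0, 3, size=500)
y = np.eye(3, dtype=np.float32)[labels]          # one-hot encode the labels

# Pre-train each layer for 200 steps, then fine-tune for 200 steps.
model.fit(x, y, batch_size=50,
          pretrain_criterion='const_iterations', pretrain_iter_num=200,
          tuning_criterion='const_iterations', tuning_iter_num=200)
print(model.predict_accuracy(x, y))

Note that the 'monitor_based' criteria additionally require summaries_dir, since they checkpoint the best model under that directory.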
289