MLP.fit()   F
last analyzed

Complexity

Conditions 13

Size

Total Lines 79

Duplication

Lines 24
Ratio 30.38 %

Importance

Changes 2
Bugs 0 Features 0
Metric Value
cc 13
c 2
b 0
f 0
dl 24
loc 79
rs 2.0788

How to fix   Long Method    Complexity   

Long Method

Small methods make your code easier to understand, in particular if combined with a good name. Besides, if your method is small, finding a good name is usually much easier.

For example, if you find yourself adding comments to a method's body, this is usually a good sign to extract the commented part to a new method, and use the comment as a starting point when coming up with a good name for this new method.

Commonly applied refactorings include: Extract Method.

Complexity

Complex methods like MLP.fit() often do a lot of different things. To break such a method down, we need to identify a cohesive component within it. A common approach to find such a component is to look for fields/methods that share the same prefixes, or suffixes.

Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a sub-class, Extract Subclass is also a candidate, and is often faster.

1
import os
import logging

import numpy as np
import tensorflow as tf

from .layers import HiddenLayer, SoftmaxLayer
from .injectors import BatchInjector
from .criterion import MonitorBased, ConstIterations

# Module-level logger following the standard `getLogger(__name__)` convention.
logger = logging.getLogger(__name__)
12
class MLP:
    """Multi-Layer Perceptron

    Args:
        num_features (:obj:`int`): Number of features.
        num_classes (:obj:`int`): Number of classes.
        layers (:obj:`list` of :obj:`int`): Sizes of the hidden layers.
        activation_fn: Activation function used in the hidden layers.
        optimizer: Optimizer used for updating weights.

    Attributes:
        num_features (:obj:`int`): Number of features.
        num_classes (:obj:`int`): Number of classes.
        x (:obj:`tensorflow.placeholder`): Input placeholder.
        y_ (:obj:`tensorflow.placeholder`): Output placeholder.
        inner_layers (:obj:`list`): List of inner hidden layers.
        summaries (:obj:`list`): List of tensorflow summaries.
        output_layer: Output softmax layer for multi-class classification,
            sigmoid layer for binary classification.
        y (:obj:`tensorflow.Tensor`): Softmax/Sigmoid output layer output tensor.
        y_class (:obj:`tensorflow.Tensor`): Tensor to get class label from output layer.
        loss (:obj:`tensorflow.Tensor`): Tensor that represents the cross-entropy loss.
        correct_prediction (:obj:`tensorflow.Tensor`): Tensor that represents the
            correctness of the classification result.
        accuracy (:obj:`tensorflow.Tensor`): Tensor that represents the accuracy of
            the classifier (exact matching ratio in multi-class classification).
        optimizer: Optimizer used for updating weights.
        fit_step (:obj:`tensorflow.Tensor`): Tensor to update weights based on the
            optimizer algorithm provided.
        sess: Tensorflow session.
        merged: Merged summaries.
    """
    def __init__(self, num_features, num_classes, layers, activation_fn=tf.sigmoid, optimizer=None):
        self.num_features = num_features
        self.num_classes = num_classes
        with tf.name_scope('input'):
            self.x = tf.placeholder(tf.float32, shape=[None, num_features], name='input_x')
            self.y_ = tf.placeholder(tf.float32, shape=[None, num_classes], name='input_y')
        self.inner_layers = []
        self.summaries = []
        # Build the stack of hidden layers: the first layer reads the input
        # placeholder, every subsequent layer reads the previous layer's output.
        for i, layer_size in enumerate(layers):
            if i == 0:
                input_size, input_tensor = num_features, self.x
            else:
                input_size, input_tensor = layers[i - 1], self.inner_layers[i - 1].y
            self.inner_layers.append(
                HiddenLayer(input_size, layer_size, x=input_tensor,
                            name=('Hidden%d' % i), activation_fn=activation_fn)
            )
            self.summaries += self.inner_layers[i].summaries
        if num_classes == 1:
            # Binary classification: single sigmoid output unit.
            self.output_layer = HiddenLayer(layers[-1], num_classes, x=self.inner_layers[-1].y,
                                            name='Output', activation_fn=tf.sigmoid)
            # Predicted probability and hard class label (threshold at 0.5).
            self.y = self.output_layer.y
            self.y_class = tf.cast(tf.greater_equal(self.y, 0.5), tf.float32)
            # Cross-entropy loss on the raw logits (numerically stable form).
            self.loss = tf.reduce_mean(
                tf.nn.sigmoid_cross_entropy_with_logits(logits=self.output_layer.logits, labels=self.y_,
                                                        name='SigmoidCrossEntropyLoss')
            )
            self.correct_prediction = tf.equal(self.y_class, self.y_)
        else:
            # Multi-class classification: softmax output layer.
            self.output_layer = SoftmaxLayer(layers[-1], num_classes, x=self.inner_layers[-1].y,
                                             name='OutputLayer')
            # Predicted probability and hard class label (argmax over classes).
            self.y = self.output_layer.y
            self.y_class = tf.argmax(self.y, 1)
            self.loss = tf.reduce_mean(
                tf.nn.softmax_cross_entropy_with_logits(logits=self.output_layer.logits, labels=self.y_,
                                                        name='SoftmaxCrossEntropyLoss')
            )
            self.correct_prediction = tf.equal(self.y_class, tf.argmax(self.y_, 1))
        # Accuracy is computed the same way in both branches; hoisted here.
        self.accuracy = tf.reduce_mean(tf.cast(self.correct_prediction, tf.float32))
        self.summaries.append(tf.summary.scalar('cross_entropy', self.loss))
        self.summaries.append(tf.summary.scalar('accuracy', self.accuracy))
        self.summaries += self.output_layer.summaries
        self.optimizer = tf.train.AdamOptimizer() if optimizer is None else optimizer
        with tf.name_scope('train'):
            self.fit_step = self.optimizer.minimize(self.loss)
        self.merged = tf.summary.merge(self.summaries)
        # Session is created lazily by fit()/predict_*() on first use.
        self.sess = None
103
    def fit(self, x, y, batch_size=100, iter_num=100,
            summaries_dir=None, summary_interval=100,
            test_x=None, test_y=None,
            session=None, criterion='const_iteration'):
        """Fit the model to the dataset.

        Args:
            x (:obj:`numpy.ndarray`): Input features of shape (num_samples, num_features).
            y (:obj:`numpy.ndarray`): Corresponding labels of shape (num_samples) for binary
                classification, or (num_samples, num_classes) for multi-class classification.
            batch_size (:obj:`int`): Batch size used in gradient descent.
            iter_num (:obj:`int`): Number of training iterations for the const-iteration
                criterion, step depth for the monitor-based stopping criterion.
            summaries_dir (:obj:`str`): Path of the directory to store summaries and saved
                values. Required when ``criterion`` is ``'monitor_based'`` (the best
                checkpoint is written there).
            summary_interval (:obj:`int`): The step interval to export variable summaries.
            test_x (:obj:`numpy.ndarray`): Test feature array used for monitoring training progress.
            test_y (:obj:`numpy.ndarray`): Test label array used for monitoring training progress.
            session (:obj:`tensorflow.Session`): Session to run training functions.
            criterion (:obj:`str`): Stopping criterion: 'const_iteration' or 'monitor_based'.
        """
        # Reuse the supplied session, else the cached one, else create (and cache) one.
        if session is None:
            if self.sess is None:
                session = tf.Session()
                self.sess = session
            else:
                session = self.sess
        # 'monitor_based' saves/restores a checkpoint under summaries_dir; without it the
        # original code raised TypeError on `summaries_dir + '/best.ckpt'`. Fail early.
        if criterion == 'monitor_based' and summaries_dir is None:
            logger.error('Criterion monitor_based requires summaries_dir to be specified.')
            return
        if summaries_dir is not None:
            train_writer = tf.summary.FileWriter(summaries_dir + '/train', session.graph)
            test_writer = tf.summary.FileWriter(summaries_dir + '/test')
            valid_writer = tf.summary.FileWriter(summaries_dir + '/valid')
        session.run(tf.global_variables_initializer())
        # Build the stopping criterion.
        if criterion == 'const_iteration':
            _criterion = ConstIterations(num_iters=iter_num)
        elif criterion == 'monitor_based':
            # Hold out the last 20% of the samples as a validation set.
            num_samples = x.shape[0]
            valid_set_len = num_samples // 5
            valid_x = x[num_samples - valid_set_len:num_samples, :]
            valid_y = y[num_samples - valid_set_len:num_samples, :]
            x = x[0:num_samples - valid_set_len, :]
            y = y[0:num_samples - valid_set_len, :]
            _criterion = MonitorBased(n_steps=iter_num,
                                      monitor_fn=self.predict_accuracy, monitor_fn_args=(valid_x, valid_y),
                                      save_fn=tf.train.Saver().save,
                                      save_fn_args=(session, summaries_dir + '/best.ckpt'))
        else:
            logger.error('Wrong criterion %s specified.' % criterion)
            return
        # Setup batch injector.
        injector = BatchInjector(data_x=x, data_y=y, batch_size=batch_size)
        i = 0
        while _criterion.continue_learning():
            batch_x, batch_y = injector.next_batch()
            if summaries_dir is not None and (i % summary_interval == 0):
                # Periodically export summaries for the train/test/valid sets.
                summary, loss, accuracy = session.run([self.merged, self.loss, self.accuracy],
                                                      feed_dict={self.x: x, self.y_: y})
                train_writer.add_summary(summary, i)
                logger.info('Step %d, train_set accuracy %g, loss %g' % (i, accuracy, loss))
                if (test_x is not None) and (test_y is not None):
                    merged, accuracy = session.run([self.merged, self.accuracy],
                                                   feed_dict={self.x: test_x, self.y_: test_y})
                    test_writer.add_summary(merged, i)
                    logger.info('test_set accuracy %g' % accuracy)
                if criterion == 'monitor_based':
                    merged, accuracy = session.run([self.merged, self.accuracy],
                                                   feed_dict={self.x: valid_x, self.y_: valid_y})
                    valid_writer.add_summary(merged, i)
                    logger.info('valid_set accuracy %g' % accuracy)
            # One gradient-descent step on the current mini-batch.
            session.run(self.fit_step, feed_dict={self.x: batch_x, self.y_: batch_y})
            i += 1
        if criterion == 'monitor_based':
            # Restore the best checkpoint saved by the monitor-based criterion.
            tf.train.Saver().restore(session, os.path.join(summaries_dir, 'best.ckpt'))
        logger.debug('Total Epoch: %d, current batch %d', injector.num_epochs, injector.cur_batch)
183
    def _ensure_session(self, session=None):
        """Return a usable session: the given one, the cached one, or a new one.

        Centralizes the session-acquisition boilerplate previously duplicated
        across predict_accuracy/predict_proba/predict.
        """
        if session is None:
            if self.sess is None:
                session = tf.Session()
                self.sess = session
            else:
                session = self.sess
        return session

    def predict_accuracy(self, x, y, session=None):
        """Get accuracy given a feature array and the corresponding labels."""
        session = self._ensure_session(session)
        return session.run(self.accuracy, feed_dict={self.x: x, self.y_: y})

    def predict_proba(self, x, session=None):
        """Predict class probabilities (output of the softmax/sigmoid layer)."""
        session = self._ensure_session(session)
        return session.run(self.y, feed_dict={self.x: x})

    def predict(self, x, session=None):
        """Predict hard class labels for the given feature array."""
        session = self._ensure_session(session)
        return session.run(self.y_class, feed_dict={self.x: x})
213