| Metric | Value |
|---|---|
| Conditions | 13 |
| Total Lines | 79 |
| Lines | 24 |
| Ratio | 30.38 % |
| Changes | 2 |
| Bugs | 0 |
| Features | 0 |
Small methods make your code easier to understand, particularly when combined with a good name. Besides, if your method is small, finding a good name is usually much easier.
For example, if you find yourself adding comments to a method's body, that is usually a good sign that the commented part should be extracted into a new method, with the comment as a starting point for naming it (a minimal sketch follows).
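A minimal sketch of this, using a hypothetical order-summarizing function (none of these names come from the project under review):

```python
# Before: the comment marks a block that wants to be its own method.
def summarize(orders):
    # keep only the orders that were actually shipped
    shipped = [o for o in orders if o.status == "shipped"]
    return sum(o.total for o in shipped)

# After: Extract Method, with the comment reborn as the method's name.
def shipped_orders(orders):
    return [o for o in orders if o.status == "shipped"]

def summarize(orders):
    return sum(o.total for o in shipped_orders(orders))
```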
Commonly applied refactorings include Extract Method.
If many parameters/temporary variables are present, Replace Method with Method Object is a common follow-up, since it turns the temporaries into fields of a dedicated class (see the sketch after this list).
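A minimal sketch of Replace Method with Method Object, again with hypothetical names:

```python
# A long method with many temporaries becomes a small class whose
# fields replace those temporaries, making further extraction easy.
class PriceCalculator:
    def __init__(self, items):
        self.items = items      # former parameter
        self.base = 0.0         # former temporary variable
        self.discount = 0.0     # former temporary variable

    def compute(self):
        self.base = sum(item.price for item in self.items)
        self.discount = 0.1 * self.base if self.base > 100 else 0.0
        return self.base - self.discount

# The original method body shrinks to:
#     return PriceCalculator(items).compute()
```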
Complex methods like MLP.fit() often do a lot of different things. To break such a method down, we need to identify a cohesive component within it or within its class. A common approach to finding such a component is to look for fields/methods that share the same prefixes or suffixes.
Once you have determined which fields belong together, you can apply the Extract Class refactoring. If the component makes sense as a subclass, Extract Subclass is also a candidate, and is often the faster option.
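For instance, the three `*_writer` locals created in fit() below share a suffix and always travel together, which suggests a cohesive component. A hedged Extract Class sketch (the SummaryWriters name and its methods are hypothetical, not part of this project's API):

```python
import tensorflow as tf

class SummaryWriters:
    """Owns the train/test/valid writers that fit() currently
    creates as three same-suffix local variables."""

    def __init__(self, summaries_dir, graph):
        self.train = tf.summary.FileWriter(summaries_dir + '/train', graph)
        self.test = tf.summary.FileWriter(summaries_dir + '/test')
        self.valid = tf.summary.FileWriter(summaries_dir + '/valid')

    def add(self, name, summary, step):
        # name is one of 'train', 'test', 'valid'
        getattr(self, name).add_summary(summary, step)
```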
```python
import os

# ... (lines elided: fit() is a method of the MLP class; `tf` is TensorFlow 1.x,
# and ConstIterations, MonitorBased, and BatchInjector are project-local helpers)

def fit(self, x, y, batch_size=100, iter_num=100,
        summaries_dir=None, summary_interval=100,
        test_x=None, test_y=None,
        session=None, criterion='const_iteration'):
    """Fit the model to the dataset.

    Args:
        x (:obj:`numpy.ndarray`): Input features of shape (num_samples, num_features).
        y (:obj:`numpy.ndarray`): Corresponding labels of shape (num_samples) for binary
            classification, or (num_samples, num_classes) for multi-class classification.
        batch_size (:obj:`int`): Batch size used in gradient descent.
        iter_num (:obj:`int`): Number of training iterations for the constant-iteration
            criterion, or step depth for the monitor-based stopping criterion.
        summaries_dir (:obj:`str`): Path of the directory to store summaries and saved values.
        summary_interval (:obj:`int`): The step interval to export variable summaries.
        test_x (:obj:`numpy.ndarray`): Test feature array used for monitoring training progress.
        test_y (:obj:`numpy.ndarray`): Test label array used for monitoring training progress.
        session (:obj:`tensorflow.Session`): Session to run training functions.
        criterion (:obj:`str`): Stopping criterion, 'const_iteration' or 'monitor_based'.
    """
    # Reuse or create a session, caching it on the instance.
    if session is None:
        if self.sess is None:
            session = tf.Session()
            self.sess = session
        else:
            session = self.sess
    if summaries_dir is not None:
        train_writer = tf.summary.FileWriter(summaries_dir + '/train', session.graph)
        test_writer = tf.summary.FileWriter(summaries_dir + '/test')
        valid_writer = tf.summary.FileWriter(summaries_dir + '/valid')
    session.run(tf.global_variables_initializer())
    # Build the stopping criterion. The monitor-based criterion holds out the
    # last fifth of the data as a validation set (and assumes summaries_dir
    # is set, since it saves the best checkpoint there).
    if criterion == 'const_iteration':
        _criterion = ConstIterations(num_iters=iter_num)
    elif criterion == 'monitor_based':
        num_samples = x.shape[0]
        valid_set_len = int(1/5 * num_samples)
        valid_x = x[num_samples-valid_set_len:num_samples, :]
        valid_y = y[num_samples-valid_set_len:num_samples, :]
        x = x[0:num_samples-valid_set_len, :]
        y = y[0:num_samples-valid_set_len, :]
        _criterion = MonitorBased(n_steps=iter_num,
                                  monitor_fn=self.predict_accuracy,
                                  monitor_fn_args=(valid_x, valid_y),
                                  save_fn=tf.train.Saver().save,
                                  save_fn_args=(session, summaries_dir + '/best.ckpt'))
    else:
        logger.error('Wrong criterion %s specified.' % criterion)
        return
    # Set up the batch injector.
    injector = BatchInjector(data_x=x, data_y=y, batch_size=batch_size)
    i = 0
    train_accuracy = 0
    while _criterion.continue_learning():
        batch_x, batch_y = injector.next_batch()
        # Periodically export summaries for the train/test/valid sets.
        if summaries_dir is not None and (i % summary_interval == 0):
            summary, loss, accuracy = session.run([self.merged, self.loss, self.accuracy],
                                                  feed_dict={self.x: x, self.y_: y})
            train_writer.add_summary(summary, i)
            train_accuracy = accuracy
            logger.info('Step %d, train_set accuracy %g, loss %g' % (i, accuracy, loss))
            if (test_x is not None) and (test_y is not None):
                merged, accuracy = session.run([self.merged, self.accuracy],
                                               feed_dict={self.x: test_x, self.y_: test_y})
                test_writer.add_summary(merged, i)
                logger.info('test_set accuracy %g' % accuracy)
            if criterion == 'monitor_based':
                merged, accuracy = session.run([self.merged, self.accuracy],
                                               feed_dict={self.x: valid_x, self.y_: valid_y})
                valid_writer.add_summary(merged, i)
                logger.info('valid_set accuracy %g' % accuracy)
        # One gradient-descent step on the current mini-batch.
        loss, accuracy, _ = session.run([self.loss, self.accuracy, self.fit_step],
                                        feed_dict={self.x: batch_x, self.y_: batch_y})
        i += 1
    # For monitor-based stopping, restore the best checkpoint seen so far.
    if criterion == 'monitor_based':
        tf.train.Saver().restore(session, os.path.join(summaries_dir, 'best.ckpt'))
    logger.debug('Total Epoch: %d, current batch %d', injector.num_epochs, injector.cur_batch)
```
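Putting the advice together, a hedged decomposition of fit() could extract one helper per comment-marked block, so the method reads at a single level of abstraction. Every underscore-prefixed helper below is a hypothetical name, not an existing project API:

```python
class MLP:
    # ... graph construction and other members elided ...

    def fit(self, x, y, batch_size=100, iter_num=100, summaries_dir=None,
            summary_interval=100, test_x=None, test_y=None,
            session=None, criterion='const_iteration'):
        session = self._ensure_session(session)
        writers = self._create_writers(summaries_dir, session)
        # _build_criterion would also carve off the validation split,
        # so it returns the (possibly reduced) training data.
        x, y, stop = self._build_criterion(criterion, iter_num, x, y,
                                           session, summaries_dir)
        injector = BatchInjector(data_x=x, data_y=y, batch_size=batch_size)
        step = 0
        while stop.continue_learning():
            batch_x, batch_y = injector.next_batch()
            if summaries_dir is not None and step % summary_interval == 0:
                self._write_summaries(writers, step, x, y, test_x, test_y, criterion)
            session.run(self.fit_step, feed_dict={self.x: batch_x, self.y_: batch_y})
            step += 1
        self._restore_best_if_needed(criterion, session, summaries_dir)

    # Stub signatures for the extracted helpers:
    def _ensure_session(self, session): ...
    def _create_writers(self, summaries_dir, session): ...
    def _build_criterion(self, criterion, iter_num, x, y, session, summaries_dir): ...
    def _write_summaries(self, writers, step, x, y, test_x, test_y, criterion): ...
    def _restore_best_if_needed(self, criterion, session, summaries_dir): ...
```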