| Metric | Value |
|---|---|
| Conditions | 26 |
| Total Lines | 148 |
| Duplicated Lines | 44 |
| Duplication Ratio | 29.73 % |
| Changes | 2 |
| Bugs | 0 |
| Features | 0 |
Small methods make your code easier to understand, in particular if combined with a good name. Besides, if your method is small, finding a good name is usually much easier.
For example, if you find yourself adding comments to a method's body, this is usually a good sign to extract the commented part to a new method, and use the comment as a starting point when coming up with a good name for this new method.
Commonly applied refactorings include Extract Method. If many parameters or temporary variables are present, refactorings such as Introduce Parameter Object (for related parameters) or Replace Temp with Query (for temporary variables) can also help.
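As an illustration, the commented stages inside the reviewed `SDA.fit()` ("Pre-training stage", "Start fine tuning") are natural extraction seams. The following is a minimal, hypothetical sketch of Extract Method applied to a method of that shape; the helper names `_pretrain_layer` and `_fine_tune`, and the simplified signature, are assumptions for illustration, not code from the project:

```python
# Hypothetical sketch of Extract Method: each commented "stage" of a long
# fit() becomes a named helper, so fit() reads like a table of contents.
class SDASketch:
    def __init__(self, inner_layers):
        self.inner_layers = inner_layers

    def fit(self, x, y, batch_size=100):
        for j, layer in enumerate(self.inner_layers):
            # was the block under "# Pre-training stage: layer by layer"
            self._pretrain_layer(j, layer, x, y, batch_size)
        # was the block under "# Start fine tuning"
        self._fine_tune(x, y, batch_size)

    def _pretrain_layer(self, index, layer, x, y, batch_size):
        """Greedy layer-wise pre-training of a single layer (body omitted)."""

    def _fine_tune(self, x, y, batch_size):
        """Supervised fine-tuning of the whole stack (body omitted)."""
```

Each helper can then carry its own name-driven documentation instead of inline comments.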
Complex methods like SDA.fit() often do a lot of different things. To break such a method down, we need to identify a cohesive component within it. A common approach to finding such a component is to look for parameters and local variables that share the same prefixes or suffixes. Once you have determined which pieces belong together, you can apply the Extract Method or Extract Class refactoring. If the component makes sense as a subclass, Extract Subclass is also a candidate, and is often faster.
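In the reviewed method, the `pretrain_*` and `tuning_*` parameters share prefixes and always travel together, which is exactly this kind of cohesive component. Below is a hedged sketch of grouping them with Introduce Parameter Object; the `StoppingConfig` name and its fields are assumptions for illustration, not part of the project:

```python
# Hypothetical sketch: the pretrain_*/tuning_* parameter pairs grouped into
# one small value object instead of four loosely related scalars.
from dataclasses import dataclass

@dataclass
class StoppingConfig:
    """How one training stage decides when to stop."""
    criterion: str = 'const_iterations'  # 'const_iterations' or 'monitor_based'
    iter_num: int = 100                  # iteration budget, or search depth for monitor_based

# fit() could then accept two cohesive arguments, e.g.
#     def fit(self, x, y, batch_size=100, pretrain=None, tuning=None, ...):
#         pretrain = pretrain or StoppingConfig()
#         tuning = tuning or StoppingConfig()
pretrain = StoppingConfig(criterion='monitor_based', iter_num=50)
tuning = StoppingConfig()  # defaults to 100 constant iterations
```

Grouping the pair also removes the risk of accidentally passing the pre-training budget to the fine-tuning stage.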
```python
import logging

# ... (remaining imports, logger setup, and earlier members of the class are omitted in this report) ...


class SDA:
    # ... (other members omitted) ...

    def fit(self, x, y, batch_size=100,
            pretrain_iter_num=100, pretrain_criterion='const_iterations',
            tuning_iter_num=100, tuning_criterion='const_iterations',
            summaries_dir=None, test_x=None, test_y=None, summary_interval=10,
            session=None):
        """Fit the model to the dataset.

        Args:
            x (:obj:`numpy.ndarray`): Input features of shape (num_samples, num_features).
            y (:obj:`numpy.ndarray`): Corresponding labels of shape (num_samples) for binary
                classification, or (num_samples, num_classes) for multi-class classification.
            batch_size (:obj:`int`): Batch size used in gradient descent.
            pretrain_iter_num (:obj:`int`): Number of constant iterations, or search depth for the
                monitor-based stopping criterion, in the pre-training stage.
            pretrain_criterion (:obj:`str`): Stopping criterion in the pre-training stage
                ('const_iterations' or 'monitor_based').
            tuning_iter_num (:obj:`int`): Number of constant iterations, or search depth for the
                monitor-based stopping criterion, in the fine-tuning stage.
            tuning_criterion (:obj:`str`): Stopping criterion in the fine-tuning stage
                ('const_iterations' or 'monitor_based').
            summaries_dir (:obj:`str`): Path of the directory to store summaries and saved values.
            summary_interval (:obj:`int`): The step interval to export variable summaries.
            test_x (:obj:`numpy.ndarray`): Test feature array used for monitoring training progress.
            test_y (:obj:`numpy.ndarray`): Test label array used for monitoring training progress.
            session (:obj:`tensorflow.Session`): Session to run training functions.
        """
        if session is None:
            if self.sess is None:
                session = tf.Session()
                self.sess = session
            else:
                session = self.sess
        session.run(tf.global_variables_initializer())
        # Pre-training stage: layer by layer
        for j in range(len(self.inner_layers)):
            current_layer = self.inner_layers[j]
            if summaries_dir is not None:
                layer_summaries_dir = '%s/pretrain_layer%d' % (summaries_dir, j)
                train_writer = tf.summary.FileWriter(layer_summaries_dir + '/train')
                test_writer = tf.summary.FileWriter(layer_summaries_dir + '/test')
                valid_writer = tf.summary.FileWriter(layer_summaries_dir + '/valid')
            # Get stopping criterion
            if pretrain_criterion == 'const_iterations':
                _pretrain_criterion = ConstIterations(num_iters=pretrain_iter_num)
                train_x = x
                train_y = y
            elif pretrain_criterion == 'monitor_based':
                # Hold out the last 20% of the samples as a validation set.
                num_samples = x.shape[0]
                valid_set_len = int(1 / 5 * num_samples)
                valid_x = x[num_samples - valid_set_len:num_samples, :]
                valid_y = y[num_samples - valid_set_len:num_samples, :]
                train_x = x[0:num_samples - valid_set_len, :]
                train_y = y[0:num_samples - valid_set_len, :]
                # The monitor-based criterion checkpoints under layer_summaries_dir,
                # so summaries_dir must be provided when this criterion is used.
                _pretrain_criterion = MonitorBased(n_steps=pretrain_iter_num,
                                                   monitor_fn=self.get_encode_loss,
                                                   monitor_fn_args=(current_layer, valid_x, valid_y),
                                                   save_fn=tf.train.Saver().save,
                                                   save_fn_args=(session, layer_summaries_dir + '/best.ckpt'))
            else:
                logger.error('Wrong criterion %s specified.' % pretrain_criterion)
                return
            injector = BatchInjector(data_x=train_x, data_y=train_y, batch_size=batch_size)
            i = 0
            # NOTE: the review tool flags this loop as duplicated code; it mirrors the
            # fine-tuning loop further below.
            while _pretrain_criterion.continue_learning():
                batch_x, batch_y = injector.next_batch()
                if summaries_dir is not None and (i % summary_interval == 0):
                    # The training summary is computed on the full dataset, not the mini-batch.
                    summary, loss = session.run(
                        [current_layer.merged, current_layer.encode_loss],
                        feed_dict={self.x: x, self.y_: y}
                    )
                    train_writer.add_summary(summary, i)
                    logger.info('Pre-training Layer %d, Step %d, training loss %g' % (j, i, loss))
                    if test_x is not None and test_y is not None:
                        summary, loss = session.run(
                            [current_layer.merged, current_layer.encode_loss],
                            feed_dict={self.x: test_x, self.y_: test_y}
                        )
                        test_writer.add_summary(summary, i)
                        logger.info('Pre-training Layer %d, Step %d, test loss %g' % (j, i, loss))
                    if pretrain_criterion == 'monitor_based':
                        summary, loss = session.run(
                            [current_layer.merged, current_layer.encode_loss],
                            feed_dict={self.x: valid_x, self.y_: valid_y}
                        )
                        valid_writer.add_summary(summary, i)
                        logger.info('Pre-training Layer %d, Step %d, valid loss %g' % (j, i, loss))
                _ = session.run(self.encode_opts[j], feed_dict={self.x: batch_x, self.y_: batch_y})
                i += 1
            if pretrain_criterion == 'monitor_based':
                # Restore the best checkpoint found by the monitor-based criterion.
                tf.train.Saver().restore(session, layer_summaries_dir + '/best.ckpt')
            if summaries_dir is not None:
                train_writer.close()
                test_writer.close()
                valid_writer.close()
        # Finish all internal layer-by-layer pre-training
        # Start fine tuning
        if summaries_dir is not None:
            tuning_summaries_dir = '%s/fine_tuning' % summaries_dir
            train_writer = tf.summary.FileWriter(tuning_summaries_dir + '/train')
            test_writer = tf.summary.FileWriter(tuning_summaries_dir + '/test')
            valid_writer = tf.summary.FileWriter(tuning_summaries_dir + '/valid')
        # Set up stopping criterion (the fine-tuning stage uses its own budget and criterion)
        if tuning_criterion == 'const_iterations':
            _tuning_criterion = ConstIterations(num_iters=tuning_iter_num)
            train_x = x
            train_y = y
        elif tuning_criterion == 'monitor_based':
            # Hold out the last 20% of the samples as a validation set.
            num_samples = x.shape[0]
            valid_set_len = int(1 / 5 * num_samples)
            valid_x = x[num_samples - valid_set_len:num_samples, :]
            valid_y = y[num_samples - valid_set_len:num_samples, :]
            train_x = x[0:num_samples - valid_set_len, :]
            train_y = y[0:num_samples - valid_set_len, :]
            _tuning_criterion = MonitorBased(n_steps=tuning_iter_num,
                                             monitor_fn=self.predict_accuracy,
                                             monitor_fn_args=(valid_x, valid_y),
                                             save_fn=tf.train.Saver().save,
                                             save_fn_args=(session, tuning_summaries_dir + '/best.ckpt'))
        else:
            logger.error('Wrong criterion %s specified.' % tuning_criterion)
            return
        injector = BatchInjector(data_x=train_x, data_y=train_y, batch_size=batch_size)
        i = 0
        # NOTE: the review tool flags this loop as duplicated code; it mirrors the
        # pre-training loop above.
        while _tuning_criterion.continue_learning():
            batch_x, batch_y = injector.next_batch()
            if summaries_dir is not None and (i % summary_interval == 0):
                summary, loss, accuracy = session.run([self.merged, self.loss, self.accuracy],
                                                      feed_dict={self.x: train_x, self.y_: train_y})
                train_writer.add_summary(summary, i)
                logger.info('Fine-Tuning: Step %d, training accuracy %g, loss %g' % (i, accuracy, loss))
                if (test_x is not None) and (test_y is not None):
                    merged, accuracy = session.run([self.merged, self.accuracy],
                                                   feed_dict={self.x: test_x, self.y_: test_y})
                    test_writer.add_summary(merged, i)
                    logger.info('Fine-Tuning: Step %d, test accuracy %g' % (i, accuracy))
                if tuning_criterion == 'monitor_based':
                    merged, accuracy = session.run([self.merged, self.accuracy],
                                                   feed_dict={self.x: valid_x, self.y_: valid_y})
                    valid_writer.add_summary(merged, i)
                    logger.info('Fine-Tuning: Step %d, valid accuracy %g' % (i, accuracy))
            _ = session.run(self.fine_tuning, feed_dict={self.x: batch_x, self.y_: batch_y})
            i += 1
        if tuning_criterion == 'monitor_based':
            tf.train.Saver().restore(session, tuning_summaries_dir + '/best.ckpt')
        if summaries_dir is not None:
            train_writer.close()
            test_writer.close()
            valid_writer.close()
```
247 | |||
289 |
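The two loops flagged as duplicated above differ mainly in which tensors they evaluate, which optimisation op they run, and which stopping criterion they consult. As a hedged sketch of how that duplication could be factored out, the helper below and its parameters are assumptions for illustration, not code from the project:

```python
# Hypothetical sketch of a shared training loop for the pre-training and
# fine-tuning stages. Names and parameters are illustrative only.
def run_training_loop(session, criterion, injector, train_op, make_feed,
                      write_summaries=None, summary_interval=10):
    """Run train_op until `criterion` says stop, optionally writing summaries.

    Args:
        session: tf.Session used for all run() calls.
        criterion: object exposing continue_learning() (e.g. ConstIterations or MonitorBased).
        injector: BatchInjector yielding (batch_x, batch_y) mini-batches.
        train_op: optimisation op to run at every step.
        make_feed: callable mapping (batch_x, batch_y) to a feed_dict.
        write_summaries: optional callable(step) that logs train/test/valid summaries.
        summary_interval: how often write_summaries is invoked.
    """
    step = 0
    while criterion.continue_learning():
        batch_x, batch_y = injector.next_batch()
        if write_summaries is not None and step % summary_interval == 0:
            write_summaries(step)
        session.run(train_op, feed_dict=make_feed(batch_x, batch_y))
        step += 1
    return step
```

With such a helper, the pre-training and fine-tuning stages would differ only in the op, the feeds, and the summary callback they pass in.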