|
@@ 107-128 (lines=22) @@
|
| 104 |
|
monitor_fn=self.predict_accuracy, |
| 105 |
|
monitor_fn_args=(valid_x, valid_y[self.num_steps:, :]), |
| 106 |
|
save_fn=tf.train.Saver().save, |
| 107 |
|
save_fn_args=(session, summaries_dir + '/best.ckpt')) |
| 108 |
|
else: |
| 109 |
|
logger.error('Wrong criterion %s specified.' % criterion) |
| 110 |
|
return |
| 111 |
|
# Train/Test sequence for brief reporting of accuracy and loss |
| 112 |
|
train_seq_x, train_seq_y = BatchSequenceInjector.to_sequence( |
| 113 |
|
self.num_steps, x, y, start=0, end=2000 |
| 114 |
|
) |
| 115 |
|
if (test_x is not None) and (test_y is not None): |
| 116 |
|
test_seq_x, test_seq_y = BatchSequenceInjector.to_sequence( |
| 117 |
|
self.num_steps, test_x, test_y, start=0, end=2000 |
| 118 |
|
) |
| 119 |
|
# Iteration Starts |
| 120 |
|
i = 0 |
| 121 |
|
while _criterion.continue_learning(): |
| 122 |
|
batch_x, batch_y = injector.next_batch() |
| 123 |
|
if summaries_dir is not None and (i % summary_interval == 0): |
| 124 |
|
summary, loss, accuracy = session.run( |
| 125 |
|
[self.merged, self.loss, self.accuracy], |
| 126 |
|
feed_dict={self.x: train_seq_x, self.y_: train_seq_y, |
| 127 |
|
self.init_state: np.zeros((train_seq_x.shape[0], 2 * self.num_units))} |
| 128 |
|
) |
| 129 |
|
train_writer.add_summary(summary, i) |
| 130 |
|
logger.info('Step %d, train_set accuracy %g, loss %g' % (i, accuracy, loss)) |
| 131 |
|
if (test_x is not None) and (test_y is not None): |
|
@@ 130-149 (lines=20) @@
|
| 127 |
|
self.init_state: np.zeros((train_seq_x.shape[0], 2 * self.num_units))} |
| 128 |
|
) |
| 129 |
|
train_writer.add_summary(summary, i) |
| 130 |
|
logger.info('Step %d, train_set accuracy %g, loss %g' % (i, accuracy, loss)) |
| 131 |
|
if (test_x is not None) and (test_y is not None): |
| 132 |
|
merged, accuracy = session.run( |
| 133 |
|
[self.merged, self.accuracy], |
| 134 |
|
feed_dict={self.x: test_seq_x, self.y_: test_seq_y, |
| 135 |
|
self.init_state: np.zeros((test_seq_x.shape[0], 2*self.num_units))}) |
| 136 |
|
test_writer.add_summary(merged, i) |
| 137 |
|
logger.info('test_set accuracy %g' % accuracy) |
| 138 |
|
loss, accuracy, _ = session.run( |
| 139 |
|
[self.loss, self.accuracy, self.fit_step], |
| 140 |
|
feed_dict={self.x: batch_x, self.y_: batch_y, |
| 141 |
|
self.init_state: np.zeros((batch_x.shape[0], 2 * self.num_units))}) |
| 142 |
|
i += 1 |
| 143 |
|
# Finish Iteration |
| 144 |
|
if criterion == 'monitor_based': |
| 145 |
|
tf.train.Saver().restore(session, os.path.join(summaries_dir, 'best.ckpt')) |
| 146 |
|
logger.debug('Total Epoch: %d, current batch %d', injector.num_epochs, injector.cur_batch) |
| 147 |
|
|
| 148 |
|
def predict_proba(self, x, session=None, batch_size=500): |
| 149 |
|
"""Predict probability (Softmax) |
| 150 |
|
""" |
| 151 |
|
if session is None: |
| 152 |
|
if self.sess is None: |