|
@@ 673-684 (lines=12) @@
|
| 670 |
|
feed_dict={self.x: x, self.length: x.shape[0] - self.num_skip, |
| 671 |
|
self.init_state: np.zeros(2*self.num_units)}) |
| 672 |
|
|
| 673 |
|
def predict_accuracy(self, x, y, session=None):
    """Get Accuracy given feature array and corresponding labels.

    Evaluates ``self.accuracy`` with ``x`` fed to ``self.x`` and ``y`` fed
    to ``self.y_``.

    Args:
        x: Feature array, or any array-like (e.g. nested lists). It is
            converted with ``np.asarray`` so that ``x.shape[0]`` is always
            available. The fed ``length`` is ``x.shape[0] - self.num_skip``.
        y: Labels corresponding to ``x``; fed to ``self.y_`` unchanged.
        session: Session to run in. When ``None``, ``self.sess`` is used.
            If that is also ``None``, a new ``tf.Session()`` is created and
            cached on ``self.sess``.

    Returns:
        Whatever ``session.run(self.accuracy, ...)`` returns.
    """
    if session is None:
        if self.sess is None:
            # NOTE(review): this method never initializes or restores
            # variables in a newly created session; confirm callers do.
            self.sess = tf.Session()
        session = self.sess
    # Lists and other array-likes have no ``.shape``; normalize first.
    # np.asarray returns an existing ndarray unchanged.
    x = np.asarray(x)
    return session.run(self.accuracy, feed_dict={
        self.x: x,
        self.y_: y,
        self.length: x.shape[0] - self.num_skip,
        # All-zero initial state vector of length 2 * num_units.
        self.init_state: np.zeros(2 * self.num_units),
    })
|
@@ 649-660 (lines=12) @@
|
| 646 |
|
if criterion == 'monitor_based': |
| 647 |
|
tf.train.Saver().restore(session, os.path.join(summaries_dir, 'best.ckpt')) |
| 648 |
|
|
| 649 |
|
def predict_proba(self, x, session=None, batch_size=500):
    """Predict probability (Softmax) by evaluating ``self.y`` on ``x``.

    Args:
        x: Feature array, or any array-like (e.g. nested lists). It is
            converted with ``np.asarray`` so that ``x.shape[0]`` is always
            available. The fed ``length`` is ``x.shape[0] - self.num_skip``.
        session: Session to run in. When ``None``, ``self.sess`` is used.
            If that is also ``None``, a new ``tf.Session()`` is created and
            cached on ``self.sess``.
        batch_size: Currently ignored. All of ``x`` is fed in a single
            ``session.run`` call together with one ``init_state``, so
            splitting it into batches would change the result. The
            parameter is kept only for backward compatibility.

    Returns:
        Whatever ``session.run(self.y, ...)`` returns.
    """
    if session is None:
        if self.sess is None:
            # NOTE(review): this method never initializes or restores
            # variables in a newly created session; confirm callers do.
            self.sess = tf.Session()
        session = self.sess
    # Lists and other array-likes have no ``.shape``; normalize first.
    # np.asarray returns an existing ndarray unchanged.
    x = np.asarray(x)
    return session.run(self.y, feed_dict={
        self.x: x,
        self.length: x.shape[0] - self.num_skip,
        # All-zero initial state vector of length 2 * num_units.
        self.init_state: np.zeros(2 * self.num_units),
    })
| 661 |
|
|
| 662 |
|
def predict(self, x, session=None): |
| 663 |
|
if session is None: |
|
@@ 662-671 (lines=10) @@
|
| 659 |
|
feed_dict={self.x: x, self.length: x.shape[0] - self.num_skip, |
| 660 |
|
self.init_state: np.zeros(2*self.num_units)}) |
| 661 |
|
|
| 662 |
|
def predict(self, x, session=None):
    """Predict classes by evaluating ``self.y_class`` on ``x``.

    Args:
        x: Feature array, or any array-like (e.g. nested lists). It is
            converted with ``np.asarray`` so that ``x.shape[0]`` is always
            available. The fed ``length`` is ``x.shape[0] - self.num_skip``.
        session: Session to run in. When ``None``, ``self.sess`` is used.
            If that is also ``None``, a new ``tf.Session()`` is created and
            cached on ``self.sess``.

    Returns:
        Whatever ``session.run(self.y_class, ...)`` returns.
    """
    if session is None:
        if self.sess is None:
            # NOTE(review): this method never initializes or restores
            # variables in a newly created session; confirm callers do.
            self.sess = tf.Session()
        session = self.sess
    # Lists and other array-likes have no ``.shape``; normalize first.
    # np.asarray returns an existing ndarray unchanged.
    x = np.asarray(x)
    return session.run(self.y_class, feed_dict={
        self.x: x,
        self.length: x.shape[0] - self.num_skip,
        # All-zero initial state vector of length 2 * num_units.
        self.init_state: np.zeros(2 * self.num_units),
    })
| 672 |
|
|
| 673 |
|
def predict_accuracy(self, x, y, session=None): |
| 674 |
|
"""Get Accuracy given feature array and corresponding labels |