@@ 673-684 (lines=12) @@ | ||
670 | feed_dict={self.x: x, self.length: x.shape[0] - self.num_skip, |
|
671 | self.init_state: np.zeros(2*self.num_units)}) |
|
672 | ||
def predict_accuracy(self, x, y, session=None):
    """Evaluate ``self.accuracy`` for one feature array and its labels.

    Args:
        x: Feature array. ``x.shape[0] - self.num_skip`` is fed to
            ``self.length``.
        y: Labels, fed to ``self.y_``.
        session: Session to run in. When omitted, ``self.sess`` is used;
            if that is also ``None``, a new ``tf.Session()`` is created and
            cached on ``self.sess`` for subsequent calls.

    Returns:
        Whatever ``session.run(self.accuracy, ...)`` returns.
    """
    if session is None:
        # Lazily create the shared session the first time one is needed.
        if self.sess is None:
            self.sess = tf.Session()
        session = self.sess

    feed = {
        self.x: x,
        self.y_: y,
        # NOTE(review): presumably the number of steps left after skipping
        # the first num_skip entries of x -- confirm against the graph.
        self.length: x.shape[0] - self.num_skip,
        # Zero initial recurrent state of size 2 * num_units (presumably
        # cell and hidden state concatenated -- confirm).
        self.init_state: np.zeros(2 * self.num_units),
    }
    return session.run(self.accuracy, feed_dict=feed)
|
@@ 649-660 (lines=12) @@ | ||
646 | if criterion == 'monitor_based': |
|
647 | tf.train.Saver().restore(session, os.path.join(summaries_dir, 'best.ckpt')) |
|
648 | ||
def predict_proba(self, x, session=None, batch_size=500):
    """Predict probabilities (softmax output ``self.y``) for one input array.

    The entire ``x`` is fed to the graph in a single ``session.run`` call.

    Args:
        x: Feature array. ``x.shape[0] - self.num_skip`` is fed to
            ``self.length``.
        session: Session to run in. When omitted, ``self.sess`` is used;
            if that is also ``None``, a new ``tf.Session()`` is created and
            cached on ``self.sess`` for subsequent calls.
        batch_size: Ignored. Accepted only so existing callers that pass it
            keep working; ``x`` is never split into batches.

    Returns:
        The value returned by ``session.run(self.y, ...)``.
    """
    # NOTE(review): batch_size is not used by this method. Chunking is not
    # implemented: every run feeds a fresh zero ``init_state`` and a length
    # derived from the whole of ``x``, so splitting ``x`` would presumably
    # change the output -- confirm before adding batching.
    del batch_size
    if session is None:
        if self.sess is None:
            session = tf.Session()
            self.sess = session
        else:
            session = self.sess
    return session.run(
        self.y,
        feed_dict={
            self.x: x,
            self.length: x.shape[0] - self.num_skip,
            # Zero initial recurrent state of size 2 * num_units
            # (presumably cell and hidden state concatenated -- confirm).
            self.init_state: np.zeros(2 * self.num_units),
        },
    )
|
661 | ||
662 | def predict(self, x, session=None): |
|
663 | if session is None: |
|
@@ 662-671 (lines=10) @@ | ||
659 | feed_dict={self.x: x, self.length: x.shape[0] - self.num_skip, |
|
660 | self.init_state: np.zeros(2*self.num_units)}) |
|
661 | ||
def predict(self, x, session=None):
    """Return the predicted classes (``self.y_class``) for one input array.

    Args:
        x: Feature array. ``x.shape[0] - self.num_skip`` is fed to
            ``self.length``.
        session: Session to run in. When omitted, ``self.sess`` is used;
            if that is also ``None``, a new ``tf.Session()`` is created and
            cached on ``self.sess`` for subsequent calls.

    Returns:
        Whatever ``session.run(self.y_class, ...)`` returns.
    """
    if session is None and self.sess is None:
        # No session anywhere yet: create one and remember it.
        self.sess = tf.Session()
    active = self.sess if session is None else session

    zero_state = np.zeros(2 * self.num_units)  # presumably (c, h) concatenated -- confirm
    steps = x.shape[0] - self.num_skip
    return active.run(
        self.y_class,
        feed_dict={self.x: x, self.length: steps, self.init_state: zero_state},
    )
|
672 | ||
673 | def predict_accuracy(self, x, y, session=None): |
|
674 | """Get Accuracy given feature array and corresponding labels |