@@ 421-442 (lines=22) @@ | ||
418 | saver.restore(session, os.path.join(summaries_dir, 'best.ckpt')) |
|
419 | logger.debug('Total Epoch: %d, current batch %d', injector.num_epochs, injector.cur_batch) |
|
420 | ||
def predict_proba(self, x, session=None, writer=None, writer_id=None):
    """Predict probability (softmax) output of ``self.y`` for one sequence.

    ``x`` is fed as a batch of size 1, together with its length and an
    all-zero initial state (``initial_state_c`` / ``initial_state_h``,
    presumably the LSTM cell/hidden state -- confirm against the graph).

    Args:
        x: a single, un-batched input sequence (numpy array or array-like).
            ``x.shape[0]`` is fed as the sequence length. A leading batch
            axis of size 1 is added before feeding.
        session: session to run in. When ``None``, ``self.sess`` is used.
            If that is also ``None``, a new ``tf.Session()`` is created and
            cached on ``self.sess``.
            NOTE(review): a newly created session is not initialized or
            restored here. Callers presumably restore weights first.
        writer: optional summary writer. When given, ``self.merged`` is
            evaluated as well and passed to ``writer.add_summary``.
        writer_id: step value passed to ``writer.add_summary``.

    Returns:
        ``self.y`` evaluated for the single batch element, i.e.
        ``result[0, :, :]`` of the rank-3 output (presumably
        time x classes -- confirm).
    """
    # Fall back to the cached session, creating one on first use.
    if session is None:
        if self.sess is None:
            self.sess = tf.Session()
        session = self.sess

    x = np.asarray(x)
    targets = [self.y]
    if writer is not None:
        targets.append(self.merged)

    zero_state = np.zeros((1, self.num_units))
    feed_dict = {
        self.x: x.reshape((1,) + x.shape),  # add batch axis of size 1
        # Builtin ``int``: ``np.int`` was only an alias of it and was
        # removed in NumPy 1.24, so the old spelling raises AttributeError.
        self.length: np.array([x.shape[0]], dtype=int),
        self.initial_state_c: zero_state,
        self.initial_state_h: zero_state,
    }
    results = session.run(targets, feed_dict=feed_dict)

    if writer is not None:
        writer.add_summary(results[1], writer_id)
    # Drop the batch axis: return the only element of the batch.
    return results[0][0, :, :]
|
443 | ||
444 | def predict(self, x, session=None, writer=None, writer_id=None): |
|
445 | if session is None: |
|
@@ 444-463 (lines=20) @@ | ||
441 | # Get result |
|
442 | return batch_y[0, :, :] |
|
443 | ||
def predict(self, x, session=None, writer=None, writer_id=None):
    """Evaluate ``self.y_class`` for one sequence.

    ``x`` is fed as a batch of size 1, together with its length and an
    all-zero initial state (``initial_state_c`` / ``initial_state_h``,
    presumably the LSTM cell/hidden state -- confirm against the graph).

    Args:
        x: a single, un-batched input sequence (numpy array or array-like).
            ``x.shape[0]`` is fed as the sequence length. A leading batch
            axis of size 1 is added before feeding.
        session: session to run in. When ``None``, ``self.sess`` is used.
            If that is also ``None``, a new ``tf.Session()`` is created and
            cached on ``self.sess``.
            NOTE(review): a newly created session is not initialized or
            restored here. Callers presumably restore weights first.
        writer: optional summary writer. When given, ``self.merged`` is
            evaluated as well and passed to ``writer.add_summary``.
        writer_id: step value passed to ``writer.add_summary``.

    Returns:
        ``self.y_class`` evaluated for the single batch element, i.e.
        ``result[0, :]`` of the rank-2 output (presumably one class index
        per timestep -- confirm).
    """
    # Fall back to the cached session, creating one on first use.
    if session is None:
        if self.sess is None:
            self.sess = tf.Session()
        session = self.sess

    x = np.asarray(x)
    targets = [self.y_class]
    if writer is not None:
        targets.append(self.merged)

    zero_state = np.zeros((1, self.num_units))
    feed_dict = {
        self.x: x.reshape((1,) + x.shape),  # add batch axis of size 1
        # Builtin ``int``: ``np.int`` was only an alias of it and was
        # removed in NumPy 1.24, so the old spelling raises AttributeError.
        self.length: np.array([x.shape[0]], dtype=int),
        self.initial_state_c: zero_state,
        self.initial_state_h: zero_state,
    }
    results = session.run(targets, feed_dict=feed_dict)

    if writer is not None:
        writer.add_summary(results[1], writer_id)
    # Drop the batch axis: return the only element of the batch.
    return results[0][0, :]
|
464 | ||
465 | def predict_accuracy(self, x, y, session=None, writer=None, writer_id=None, with_loss=False): |
|
466 | """Get Accuracy given feature array and corresponding labels |