Code Duplication    Length = 20-22 lines in 2 locations

pyActLearn/learning/nn/lstm.py 2 locations

@@ 421-442 (lines=22) @@
418
            saver.restore(session, os.path.join(summaries_dir, 'best.ckpt'))
419
        logger.debug('Total Epoch: %d, current batch %d', injector.num_epochs, injector.cur_batch)
420
421
    def predict_proba(self, x, session=None, writer=None, writer_id=None):
        """Predict probability (Softmax)

        Evaluates ``self.y`` for a single input sequence. The sequence is fed
        as a batch of one, with all-zero ``initial_state_c`` and
        ``initial_state_h``.

        Args:
            x: Numpy array holding one sequence. A leading batch axis of size 1
                is prepended before feeding, and ``x.shape[0]`` is fed as the
                sequence length.
            session: Session used to run the graph. When ``None``,
                ``self.sess`` is used; if that is also ``None`` a new
                ``tf.Session`` is created and cached on ``self.sess``.
            writer: Optional summary writer. When given, ``self.merged`` is
                evaluated as well and handed to ``writer.add_summary``.
            writer_id: Step value forwarded to ``writer.add_summary``.

        Returns:
            ``batch_y[0, :, :]``, i.e. the first (only) batch entry of the
            evaluated ``self.y``.
        """
        if session is None:
            # Lazily create the shared session and reuse it on later calls.
            if self.sess is None:
                self.sess = tf.Session()
            session = self.sess
        targets = [self.y]
        if writer is not None:
            targets.append(self.merged)
        feed_dict = {
            self.x: x.reshape((1,) + x.shape),
            # The ``np.int`` alias was removed in NumPy 1.24; it was the
            # builtin ``int``, so the resulting dtype is unchanged.
            self.length: np.array([x.shape[0]], dtype=int),
            self.initial_state_c: np.zeros((1, self.num_units)),
            self.initial_state_h: np.zeros((1, self.num_units)),
        }
        results = session.run(targets, feed_dict=feed_dict)
        if writer is not None:
            writer.add_summary(results[1], writer_id)
        batch_y = results[0]
        # Drop the batch axis that was added for feeding.
        return batch_y[0, :, :]
443
444
    def predict(self, x, session=None, writer=None, writer_id=None):
445
        if session is None:
@@ 444-463 (lines=20) @@
441
        # Get result
442
        return batch_y[0, :, :]
443
444
    def predict(self, x, session=None, writer=None, writer_id=None):
        """Predict class output for a single input sequence

        Evaluates ``self.y_class`` for one sequence. The sequence is fed as a
        batch of one, with all-zero ``initial_state_c`` and
        ``initial_state_h``.

        Args:
            x: Numpy array holding one sequence. A leading batch axis of size 1
                is prepended before feeding, and ``x.shape[0]`` is fed as the
                sequence length.
            session: Session used to run the graph. When ``None``,
                ``self.sess`` is used; if that is also ``None`` a new
                ``tf.Session`` is created and cached on ``self.sess``.
            writer: Optional summary writer. When given, ``self.merged`` is
                evaluated as well and handed to ``writer.add_summary``.
            writer_id: Step value forwarded to ``writer.add_summary``.

        Returns:
            ``batch_y[0, :]``, i.e. the first (only) batch entry of the
            evaluated ``self.y_class``.
        """
        if session is None:
            # Lazily create the shared session and reuse it on later calls.
            if self.sess is None:
                self.sess = tf.Session()
            session = self.sess
        targets = [self.y_class]
        if writer is not None:
            targets.append(self.merged)
        feed_dict = {
            self.x: x.reshape((1,) + x.shape),
            # The ``np.int`` alias was removed in NumPy 1.24; it was the
            # builtin ``int``, so the resulting dtype is unchanged.
            self.length: np.array([x.shape[0]], dtype=int),
            self.initial_state_c: np.zeros((1, self.num_units)),
            self.initial_state_h: np.zeros((1, self.num_units)),
        }
        results = session.run(targets, feed_dict=feed_dict)
        if writer is not None:
            writer.add_summary(results[1], writer_id)
        batch_y = results[0]
        # Drop the batch axis that was added for feeding.
        return batch_y[0, :]
464
465
    def predict_accuracy(self, x, y, session=None, writer=None, writer_id=None, with_loss=False):
466
        """Get Accuracy given feature array and corresponding labels