|
@@ 152-173 (lines=22) @@
|
| 149 |
|
tf.train.Saver().restore(session, os.path.join(summaries_dir, 'best.ckpt')) |
| 150 |
|
logger.debug('Total Epoch: %d, current batch %d', injector.num_epochs, injector.cur_batch) |
| 151 |
|
|
| 152 |
|
def predict_proba(self, x, session=None, batch_size=500):
    """Predict probability (Softmax) for every sequence in ``x``.

    Runs ``self.y`` batch by batch over ``x`` and stacks the per-batch
    outputs along axis 0.

    Args:
        x: Input data passed to ``BatchSequenceInjector`` as ``data_x``.
        session: TensorFlow session to run in. When ``None``, ``self.sess``
            is used; if that is also ``None``, a new ``tf.Session()`` is
            created and cached on ``self.sess``.
        batch_size: Number of rows fed per ``session.run`` call.

    Returns:
        The outputs of ``self.y`` for all batches concatenated along axis 0,
        or ``None`` if the injector produced no batches.
    """
    if session is None:
        if self.sess is None:
            # NOTE(review): a brand-new session has no restored/initialized
            # variables; presumably training/restore has already populated
            # self.sess before prediction -- confirm.
            self.sess = tf.Session()
        session = self.sess

    injector = BatchSequenceInjector(batch_size=batch_size, data_x=x, seq_len=self.num_steps)
    injector.reset()

    # Gather per-batch outputs and concatenate once at the end. Concatenating
    # inside the loop would re-copy everything collected so far on every
    # iteration (quadratic copying).
    batch_outputs = []
    # Iterate until the injector reports one completed pass over x.
    # NOTE(review): if the injector pads or wraps the final batch, the result
    # may contain more rows than x -- confirm against BatchSequenceInjector.
    while injector.num_epochs == 0:
        batch_x = injector.next_batch()
        # All-zero initial recurrent state, one row per sequence in the batch;
        # the 2 * num_units width presumably packs two state vectors -- confirm.
        zero_state = np.zeros((batch_x.shape[0], 2 * self.num_units))
        batch_outputs.append(
            session.run(self.y, feed_dict={self.x: batch_x, self.init_state: zero_state})
        )

    if not batch_outputs:
        return None
    return np.concatenate(batch_outputs, axis=0)
| 174 |
|
|
| 175 |
|
def predict(self, x, session=None, batch_size=500): |
| 176 |
|
if session is None: |
|
@@ 175-194 (lines=20) @@
|
| 172 |
|
result = np.concatenate((result, batch_y), axis=0) |
| 173 |
|
return result |
| 174 |
|
|
| 175 |
|
def predict(self, x, session=None, batch_size=500):
    """Run ``self.y_class`` over ``x`` in batches and return the stacked output.

    Presumably ``self.y_class`` yields the predicted class per sequence
    (inferred from its name -- confirm against the graph definition).

    Args:
        x: Input data passed to ``BatchSequenceInjector`` as ``data_x``.
        session: TensorFlow session to run in. When ``None``, ``self.sess``
            is used; if that is also ``None``, a new ``tf.Session()`` is
            created and cached on ``self.sess``.
        batch_size: Number of rows fed per ``session.run`` call.

    Returns:
        The outputs of ``self.y_class`` for all batches concatenated along
        axis 0, or ``None`` if the injector produced no batches.
    """
    if session is None:
        if self.sess is None:
            # NOTE(review): a brand-new session has no restored/initialized
            # variables; presumably training/restore has already populated
            # self.sess before prediction -- confirm.
            self.sess = tf.Session()
        session = self.sess

    injector = BatchSequenceInjector(batch_size=batch_size, data_x=x, seq_len=self.num_steps)
    injector.reset()

    # Gather per-batch outputs and concatenate once at the end. Concatenating
    # inside the loop would re-copy everything collected so far on every
    # iteration (quadratic copying).
    batch_outputs = []
    # Iterate until the injector reports one completed pass over x.
    # NOTE(review): if the injector pads or wraps the final batch, the result
    # may contain more rows than x -- confirm against BatchSequenceInjector.
    while injector.num_epochs == 0:
        batch_x = injector.next_batch()
        # All-zero initial recurrent state, one row per sequence in the batch;
        # the 2 * num_units width presumably packs two state vectors -- confirm.
        zero_state = np.zeros((batch_x.shape[0], 2 * self.num_units))
        batch_outputs.append(
            session.run(self.y_class, feed_dict={self.x: batch_x, self.init_state: zero_state})
        )

    if not batch_outputs:
        return None
    return np.concatenate(batch_outputs, axis=0)
| 195 |
|
|
| 196 |
|
def predict_accuracy(self, x, y, session=None): |
| 197 |
|
"""Get Accuracy given feature array and corresponding labels |