@@ 152-173 (lines=22) @@ | ||
149 | tf.train.Saver().restore(session, os.path.join(summaries_dir, 'best.ckpt')) |
|
150 | logger.debug('Total Epoch: %d, current batch %d', injector.num_epochs, injector.cur_batch) |
|
151 | ||
def predict_proba(self, x, session=None, batch_size=500):
    """Predict probability (Softmax)

    Evaluates ``self.y`` over ``x`` batch by batch and joins the per-batch
    outputs along axis 0.

    Args:
        x: Input data, handed to ``BatchSequenceInjector`` as ``data_x``.
        session: TensorFlow session used to run the graph. When ``None``,
            ``self.sess`` is used; if that is also ``None`` a new
            ``tf.Session()`` is created and cached on ``self.sess``.
        batch_size: Batch size passed to the injector.

    Returns:
        The per-batch outputs of ``self.y`` concatenated along axis 0, or
        ``None`` if the injector produced no batch.
    """
    if session is None:
        if self.sess is None:
            # NOTE(review): a session created here has had nothing restored
            # into it by this method -- confirm callers train/restore first.
            self.sess = tf.Session()
        session = self.sess
    injector = BatchSequenceInjector(batch_size=batch_size, data_x=x, seq_len=self.num_steps)
    injector.reset()
    # Collect batch outputs and concatenate once at the end; concatenating
    # inside the loop re-copies the accumulated array on every batch.
    batch_outputs = []
    # Presumably next_batch() bumps num_epochs once the data has been
    # consumed, which ends the single pass -- verify against the injector.
    while injector.num_epochs == 0:
        batch_x = injector.next_batch()
        batch_outputs.append(
            session.run(
                self.y,
                feed_dict={
                    self.x: batch_x,
                    # Zero initial state, one row per sequence in the batch.
                    self.init_state: np.zeros((batch_x.shape[0], 2 * self.num_units)),
                },
            )
        )
    if not batch_outputs:
        return None
    return np.concatenate(batch_outputs, axis=0)
|
174 | ||
175 | def predict(self, x, session=None, batch_size=500): |
|
176 | if session is None: |
|
@@ 175-194 (lines=20) @@ | ||
172 | result = np.concatenate((result, batch_y), axis=0) |
|
173 | return result |
|
174 | ||
def predict(self, x, session=None, batch_size=500):
    """Evaluate ``self.y_class`` for the input data, batch by batch.

    ``self.y_class`` is presumably the predicted class per sample (as
    opposed to the softmax output used by ``predict_proba``) -- confirm
    against the graph definition.

    Args:
        x: Input data, handed to ``BatchSequenceInjector`` as ``data_x``.
        session: TensorFlow session used to run the graph. When ``None``,
            ``self.sess`` is used; if that is also ``None`` a new
            ``tf.Session()`` is created and cached on ``self.sess``.
        batch_size: Batch size passed to the injector.

    Returns:
        The per-batch outputs of ``self.y_class`` concatenated along
        axis 0, or ``None`` if the injector produced no batch.
    """
    if session is None:
        if self.sess is None:
            # NOTE(review): a session created here has had nothing restored
            # into it by this method -- confirm callers train/restore first.
            self.sess = tf.Session()
        session = self.sess
    injector = BatchSequenceInjector(batch_size=batch_size, data_x=x, seq_len=self.num_steps)
    injector.reset()
    # Collect batch outputs and concatenate once at the end; concatenating
    # inside the loop re-copies the accumulated array on every batch.
    batch_outputs = []
    # Presumably next_batch() bumps num_epochs once the data has been
    # consumed, which ends the single pass -- verify against the injector.
    while injector.num_epochs == 0:
        batch_x = injector.next_batch()
        batch_outputs.append(
            session.run(
                self.y_class,
                feed_dict={
                    self.x: batch_x,
                    # Zero initial state, one row per sequence in the batch.
                    self.init_state: np.zeros((batch_x.shape[0], 2 * self.num_units)),
                },
            )
        )
    if not batch_outputs:
        return None
    return np.concatenate(batch_outputs, axis=0)
|
195 | ||
196 | def predict_accuracy(self, x, y, session=None): |
|
197 | """Get Accuracy given feature array and corresponding labels |