@@ 107-128 (lines=22) @@ | ||
104 | self.init_state: np.zeros((batch_x.shape[0], 2 * self.num_units))}) |
|
105 | print('Step %d, training accuracy %g, loss %g' % (i, accuracy, loss)) |
|
106 | ||
def predict_proba(self, x, session=None, batch_size=500):
    """Predict probabilities (softmax output) for ``x``.

    Evaluates the ``self.y`` tensor over ``x`` in mini-batches produced by a
    ``BatchSequenceInjector``. The per-batch outputs are stacked along axis 0.

    Args:
        x: Input data, passed to ``BatchSequenceInjector`` as ``data_x``.
        session: TensorFlow session to run in. When ``None``, ``self.sess`` is
            used. If ``self.sess`` is also ``None``, a new ``tf.Session()`` is
            created and stored on ``self.sess`` for later calls.
        batch_size: Number of sequences per batch fed to the graph.

    Returns:
        ``self.y`` evaluated on every batch, concatenated along axis 0.
        Returns ``None`` if the injector produced no batch.
    """
    if session is None:
        if self.sess is None:
            # NOTE(review): this fresh session runs no variable initializer
            # and restores no checkpoint here -- confirm callers train/restore
            # before predicting.
            self.sess = tf.Session()
        session = self.sess

    injector = BatchSequenceInjector(batch_size=batch_size, data_x=x,
                                     seq_len=self.num_steps)
    injector.reset()

    batches = []
    # Make a single pass over the data: stop once the injector reports a
    # completed epoch.
    # NOTE(review): if next_batch() wraps around at the epoch boundary, the
    # last batch may repeat rows from the start of x -- confirm in
    # BatchSequenceInjector.
    while injector.num_epochs == 0:
        batch_x = injector.next_batch()
        # Every batch starts from an all-zero recurrent state of width
        # 2 * num_units (presumably the concatenated LSTM (c, h) -- confirm).
        zero_state = np.zeros((batch_x.shape[0], 2 * self.num_units))
        batches.append(session.run(self.y,
                                   feed_dict={self.x: batch_x,
                                              self.init_state: zero_state}))

    if not batches:
        return None
    # Concatenate once. Re-concatenating the growing result on every batch
    # is quadratic in the number of batches.
    return np.concatenate(batches, axis=0)
|
129 | ||
130 | def predict(self, x, session=None, batch_size=500): |
|
131 | if session is None: |
|
@@ 130-149 (lines=20) @@ | ||
127 | result = np.concatenate((result, batch_y), axis=0) |
|
128 | return result |
|
129 | ||
def predict(self, x, session=None, batch_size=500):
    """Predict classes for ``x``.

    Evaluates the ``self.y_class`` tensor (presumably the predicted class
    labels -- confirm where it is built) over ``x`` in mini-batches produced
    by a ``BatchSequenceInjector``. The per-batch outputs are stacked along
    axis 0.

    Args:
        x: Input data, passed to ``BatchSequenceInjector`` as ``data_x``.
        session: TensorFlow session to run in. When ``None``, ``self.sess`` is
            used. If ``self.sess`` is also ``None``, a new ``tf.Session()`` is
            created and stored on ``self.sess`` for later calls.
        batch_size: Number of sequences per batch fed to the graph.

    Returns:
        ``self.y_class`` evaluated on every batch, concatenated along axis 0.
        Returns ``None`` if the injector produced no batch.
    """
    if session is None:
        if self.sess is None:
            # NOTE(review): this fresh session runs no variable initializer
            # and restores no checkpoint here -- confirm callers train/restore
            # before predicting.
            self.sess = tf.Session()
        session = self.sess

    injector = BatchSequenceInjector(batch_size=batch_size, data_x=x,
                                     seq_len=self.num_steps)
    injector.reset()

    batches = []
    # Make a single pass over the data: stop once the injector reports a
    # completed epoch.
    # NOTE(review): if next_batch() wraps around at the epoch boundary, the
    # last batch may repeat rows from the start of x -- confirm in
    # BatchSequenceInjector.
    while injector.num_epochs == 0:
        batch_x = injector.next_batch()
        # Every batch starts from an all-zero recurrent state of width
        # 2 * num_units (presumably the concatenated LSTM (c, h) -- confirm).
        zero_state = np.zeros((batch_x.shape[0], 2 * self.num_units))
        batches.append(session.run(self.y_class,
                                   feed_dict={self.x: batch_x,
                                              self.init_state: zero_state}))

    if not batches:
        return None
    # Concatenate once. Re-concatenating the growing result on every batch
    # is quadratic in the number of batches.
    return np.concatenate(batches, axis=0)
|
150 |