| @@ 121-147 (lines=27) @@ | ||
| 118 | z = self._relu(T.dot(h_t, self.W_a) + self.B_a) |
|
| 119 | return self._softmax(z) |
|
| 120 | ||
def _core_network(self, l_p, h_p, x_t):
    """
    Run one step of the core network.

    Feeds the glimpse taken from x_t at l_p, together with the previous
    hidden state, through the recurrent layer, then derives the next
    location and the action output from the new hidden state.

    Parameters:
        l_p - 2x1 focus vector
        h_p - 256x1 vector
        x_t - 28x28 image
    Returns:
        (sampled_l_t, h_t, a_t, wl_grad) where
        sampled_l_t - location handed to the next step
        h_t         - 256x1 vector, the new hidden state
        a_t         - output of the action network for h_t
        wl_grad     - gradient of the log-pdf of the sampled location
                      w.r.t. W_l; when reinforce is disabled, W_l itself
                      is returned in this slot
    """
    glimpse = self._glimpse_network(x_t, l_p)
    pre_activation = (
        T.dot(glimpse, self.W_h_g) + T.dot(h_p, self.W_h) + self.B_h
    )
    h_t = self._tanh(pre_activation)
    mean_location = self._location_network(h_t)

    if self.disable_reinforce:
        # No sampling: pass the predicted location straight through.
        # W_l is returned in the gradient slot (presumably a placeholder
        # so the return arity stays fixed - confirm against the caller).
        next_location, wl_grad = mean_location, self.W_l
        if self.random_glimpse:
            # Replace the predicted location with a uniform random draw.
            next_location = self.srng.uniform((2,)) * 0.8
    else:
        next_location = self._sample_gaussian(mean_location, self.cov)
        # The sample is wrapped in disconnected_grad so the gradient only
        # flows through the mean produced by the location network.
        location_pdf = self._multi_gaussian_pdf(
            disconnected_grad(next_location), mean_location
        )
        wl_grad = T.grad(T.log(location_pdf), self.W_l)

    return next_location, h_t, self._action_network(h_t), wl_grad
|
| 148 | ||
| 149 | ||
| 150 | def _output_func(self): |
|
| @@ 106-132 (lines=27) @@ | ||
| 103 | z = self._relu(T.dot(h_t, self.W_a) + self.B_a) |
|
| 104 | return self._softmax(z) |
|
| 105 | ||
def _core_network(self, l_p, h_p, x_t):
    """
    Run one step of the core network.

    Feeds the glimpse taken from x_t at l_p, together with the previous
    hidden state, through the recurrent layer, then derives the next
    location and the action output from the new hidden state.

    Parameters:
        l_p - 2x1 focus vector
        h_p - 256x1 vector
        x_t - 28x28 image
    Returns:
        (sampled_l_t, h_t, a_t, wl_grad) where
        sampled_l_t - location handed to the next step
        h_t         - 256x1 vector, the new hidden state
        a_t         - output of the action network for h_t
        wl_grad     - gradient of the log-pdf of the sampled location
                      w.r.t. W_l; when reinforce is disabled, W_l itself
                      is returned in this slot
    """
    glimpse = self._glimpse_network(x_t, l_p)
    pre_activation = (
        T.dot(glimpse, self.W_h_g) + T.dot(h_p, self.W_h) + self.B_h
    )
    h_t = self._tanh(pre_activation)
    mean_location = self._location_network(h_t)

    if self.disable_reinforce:
        # No sampling: pass the predicted location straight through.
        # W_l is returned in the gradient slot (presumably a placeholder
        # so the return arity stays fixed - confirm against the caller).
        next_location, wl_grad = mean_location, self.W_l
        if self.random_glimpse:
            # Replace the predicted location with a uniform random draw
            # over [-1.7, 1.7] in each coordinate.
            next_location = self.srng.uniform((2,), low=-1.7, high=1.7)
    else:
        next_location = self._sample_gaussian(mean_location, self.cov)
        # The sample is wrapped in disconnected_grad so the gradient only
        # flows through the mean produced by the location network.
        location_pdf = self._multi_gaussian_pdf(
            disconnected_grad(next_location), mean_location
        )
        wl_grad = T.grad(T.log(location_pdf), self.W_l)

    return next_location, h_t, self._action_network(h_t), wl_grad
|
| 133 | ||
| 134 | ||
| 135 | def _output_func(self): |
|