| @@ 157-189 (lines=33) @@ | ||
| 154 | self._softmax = build_activation("softmax") |
|
| 155 | self.output_func = self._output_func() |
|
| 156 | ||
def _setup_params(self):
    """Create the random stream, the covariance constants and every weight.

    Everything is stored on ``self``:

    * ``srng``: a ``RandomStreams`` seeded with 234.
    * ``large_cov`` / ``small_cov``: 2x2 diagonal matrices of dtype FLOATX.
      ``small_cov`` has ``self.gaussian_std`` on its diagonal.
    * ``cov``, ``cov_inv_var``, ``cov_det_var``: theano shared variables
      holding ``small_cov``, its inverse and its determinant.
    * ``_sample_gaussian``: a ``SampleMultivariateGaussian`` instance.
    * Weights ``W_*``, biases ``B_*`` and the vectors ``h0`` / ``l0``
      (``l0`` is reset to ``[-1, -1]``; ``W_l`` is scaled down by 10).
    * ``W``, ``B`` and ``parameters``: lists grouping those variables.

    NOTE(review): ``gaussian_std`` is placed on the covariance diagonal, so it
    acts as a variance rather than a standard deviation -- confirm intended.
    NOTE(review): ``B_l`` is created but appears in none of ``W``, ``B`` or
    ``parameters`` -- confirm it is meant to stay at its initial value.
    """
    self.srng = RandomStreams(seed=234)

    # 2x2 diagonal covariance matrices.
    std = self.gaussian_std
    self.large_cov = np.array([[0.06, 0], [0, 0.06]], dtype=FLOATX)
    self.small_cov = np.array([[std, 0], [0, std]], dtype=FLOATX)

    # Shared copies of small_cov and its inverse / determinant.
    cov = self.small_cov
    self.cov = theano.shared(np.array(cov, dtype=FLOATX))
    self.cov_inv_var = theano.shared(np.array(LA.inv(cov), dtype=FLOATX))
    self.cov_det_var = theano.shared(np.array(LA.det(cov), dtype=FLOATX))
    self._sample_gaussian = SampleMultivariateGaussian()

    # Layer widths used below.
    glimpse_in = 7 * 7
    narrow = 128
    wide = 256
    out_a = 10

    weight = self.create_weight
    bias = self.create_bias

    # The creation order below is kept exactly as-is: the create_* helpers
    # may draw from a shared random generator (assumption -- not visible here).
    self.W_g0 = weight(glimpse_in, narrow, suffix="g0")
    self.W_g1 = weight(2, narrow, suffix="g1")
    self.W_g2_hg = weight(narrow, wide, suffix="g2_hg")
    self.W_g2_hl = weight(narrow, wide, suffix="g2_hl")

    self.W_h_g = weight(wide, wide, suffix="h_g")
    self.W_h = weight(wide, wide, suffix="h")
    self.B_h = bias(wide, suffix="h")
    self.h0 = self.create_vector(wide, "h0")
    self.l0 = self.create_vector(2, "l0")
    self.l0.set_value(np.full(2, -1, dtype=FLOATX))

    self.W_l = weight(wide, 2, suffix="l")
    # Shrink the freshly initialised W_l by a factor of 10.
    self.W_l.set_value(self.W_l.get_value() / 10)
    self.B_l = bias(2, suffix="l")
    self.W_a = weight(wide, out_a, suffix="a")
    self.B_a = bias(out_a, suffix="a")

    self.W_f = weight(glimpse_in, 2, suffix="f")

    self.W = [
        self.W_g0, self.W_g1, self.W_g2_hg, self.W_g2_hl,
        self.W_h_g, self.W_h, self.W_a,
    ]
    self.B = [self.B_h, self.B_a]
    self.parameters = [self.W_l, self.W_f]
|
| 190 | ||
| 191 | def get_network(model=None, std=0.005, disable_reinforce=False, random_glimpse=False): |
|
| 192 | """ |
|
| @@ 170-200 (lines=31) @@ | ||
| 167 | self._softmax = build_activation("softmax") |
|
| 168 | self.output_func = self._output_func() |
|
| 169 | ||
def _setup_params(self):
    """Create the random stream, the covariance constants and every weight.

    Everything is stored on ``self``:

    * ``srng``: a ``RandomStreams`` seeded with 234.
    * ``large_cov`` / ``small_cov``: 2x2 diagonal matrices of dtype FLOATX.
      ``small_cov`` has ``self.gaussian_std`` on its diagonal.
    * ``cov``, ``cov_inv_var``, ``cov_det_var``: theano shared variables
      holding ``small_cov``, its inverse and its determinant.
    * ``_sample_gaussian``: a ``SampleMultivariateGaussian`` instance.
    * Weights ``W_*``, biases ``B_*`` and the vectors ``h0`` / ``l0``
      (``l0`` is reset to ``[-1, -1]``; ``W_l`` is scaled down by 10).
    * ``W``, ``B`` and ``parameters``: lists grouping those variables.

    NOTE(review): ``gaussian_std`` is placed on the covariance diagonal, so it
    acts as a variance rather than a standard deviation -- confirm intended.
    NOTE(review): ``B_l`` is created but appears in none of ``W``, ``B`` or
    ``parameters`` -- confirm it is meant to stay at its initial value.
    """
    self.srng = RandomStreams(seed=234)

    # 2x2 diagonal covariance matrices.
    std = self.gaussian_std
    self.large_cov = np.array([[0.06, 0], [0, 0.06]], dtype=FLOATX)
    self.small_cov = np.array([[std, 0], [0, std]], dtype=FLOATX)

    # Shared copies of small_cov and its inverse / determinant.
    cov = self.small_cov
    self.cov = theano.shared(np.array(cov, dtype=FLOATX))
    self.cov_inv_var = theano.shared(np.array(LA.inv(cov), dtype=FLOATX))
    self.cov_det_var = theano.shared(np.array(LA.det(cov), dtype=FLOATX))
    self._sample_gaussian = SampleMultivariateGaussian()

    # Layer widths used below.
    glimpse_in = 7 * 7
    narrow = 128
    wide = 256
    out_a = 10

    weight = self.create_weight
    bias = self.create_bias

    # The creation order below is kept exactly as-is: the create_* helpers
    # may draw from a shared random generator (assumption -- not visible here).
    self.W_g0 = weight(glimpse_in, narrow, suffix="g0")
    self.W_g1 = weight(2, narrow, suffix="g1")
    self.W_g2_hg = weight(narrow, wide, suffix="g2_hg")
    self.W_g2_hl = weight(narrow, wide, suffix="g2_hl")

    self.W_h_g = weight(wide, wide, suffix="h_g")
    self.W_h = weight(wide, wide, suffix="h")
    self.B_h = bias(wide, suffix="h")
    self.h0 = self.create_vector(wide, "h0")
    self.l0 = self.create_vector(2, "l0")
    self.l0.set_value(np.full(2, -1, dtype=FLOATX))

    self.W_l = weight(wide, 2, suffix="l")
    # Shrink the freshly initialised W_l by a factor of 10.
    self.W_l.set_value(self.W_l.get_value() / 10)
    self.B_l = bias(2, suffix="l")
    self.W_a = weight(wide, out_a, suffix="a")
    self.B_a = bias(out_a, suffix="a")

    self.W = [
        self.W_g0, self.W_g1, self.W_g2_hg, self.W_g2_hl,
        self.W_h_g, self.W_h, self.W_a,
    ]
    self.B = [self.B_h, self.B_a]
    self.parameters = [self.W_l]
|
| 201 | ||
| 202 | ||
| 203 | def get_network(model=None, std=0.005, disable_reinforce=False, random_glimpse=False): |
|