@@ 53-65 (lines=13) @@

```
50 |         self.opt_func = optimize_function(self.network.weights + self.network.biases, self.config)
51 |         self.rl_opt_func = optimize_function([self.layer.W_l, self.layer.W_f], self.config)
52 |
53 |     def update_parameters(self, update_rl):
54 |         if not self.disable_backprop:
55 |             grads = [self.batch_grad[i] / self.batch_size for i in range(len(self.network.weights + self.network.biases))]
56 |             self.opt_func(*grads)
57 |         # REINFORCE update
58 |         if update_rl and not self.disable_reinforce:
59 |             if np.sum(self.batch_wl_grad) == 0 or np.sum(self.batch_wf_grad) == 0:
60 |                 sys.stdout.write("0WRL ")
61 |                 sys.stdout.flush()
62 |             else:
63 |                 grad_wl = self.batch_wl_grad / self.batch_size
64 |                 grad_wf = self.batch_wf_grad / self.batch_size
65 |                 self.rl_opt_func(grad_wl, grad_wf)
66 |
67 |     def train_func(self, train_set):
68 |         cost_sum = 0.0
```
@@ 50-61 (lines=12) @@

```
47 |         # self.opt_interface = gradient_interface_future(self.network.weights + self.network.biases, config=self.config)
48 |         # self.l_opt_interface = gradient_interface_future([self.layer.W_l], config=self.config)
49 |
50 |     def update_parameters(self, update_wl):
51 |         if not self.disable_backprop:
52 |             grads = [self.batch_grad[i] / self.batch_size for i in range(len(self.network.weights + self.network.biases))]
53 |             self.opt_interface(*grads)
54 |         # REINFORCE update
55 |         if update_wl and not self.disable_reinforce:
56 |             if np.sum(self.batch_wl_grad) == 0:
57 |                 sys.stdout.write("[0 WLG] ")
58 |                 sys.stdout.flush()
59 |             else:
60 |                 grad_wl = self.batch_wl_grad / self.batch_size
61 |                 self.l_opt_interface(grad_wl)
62 |
63 |     def train_func(self, train_set):
64 |         cost_sum = 0.0
```