#!/usr/bin/env python
# -*- coding: utf-8 -*-
import sys

import numpy as np
from numpy import linalg as LA
from theano import tensor as T
import theano

from deepy.utils.functions import FLOATX
from deepy.trainers import CustomizeTrainer
from deepy.trainers.optimize import optimize_function


class AttentionTrainer(CustomizeTrainer):

    def __init__(self, network, attention_layer, config):
        """
        Parameters:
            network - AttentionNetwork
            attention_layer - attention layer whose W_l is updated via REINFORCE
            config - training config
        :type network: NeuralClassifier
        :type attention_layer: experiments.attention_models.baseline_model.AttentionLayer
        :type config: TrainerConfig
        """
        super(AttentionTrainer, self).__init__(network, config)
        self.large_cov_mode = False
        self.batch_size = config.get("batch_size", 20)
        self.disable_backprop = config.get("disable_backprop", False)
        self.disable_reinforce = config.get("disable_reinforce", False)
        self.last_average_reward = 999
        self.turn = 1
        self.layer = attention_layer
        if self.disable_backprop:
            grads = []
        else:
            grads = [T.grad(self.cost, p) for p in network.weights + network.biases]
        if self.disable_reinforce:
            grad_l = self.layer.W_l
        else:
            grad_l = self.layer.wl_grad
        self.batch_wl_grad = np.zeros(attention_layer.W_l.get_value().shape, dtype=FLOATX)
        self.batch_grad = [np.zeros(p.get_value().shape, dtype=FLOATX) for p in network.weights + network.biases]
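        # Compiled below: grad_func returns [cost, W_l gradient, attention positions,
        # last decision] followed by the gradients of all network weights and biases.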
        self.grad_func = theano.function(network.inputs, [self.cost, grad_l, attention_layer.positions, attention_layer.last_decision] + grads, allow_input_downcast=True)
        self.opt_interface = optimize_function(self.network.weights + self.network.biases, self.config)
        self.l_opt_interface = optimize_function([self.layer.W_l], self.config)
        # self.opt_interface = gradient_interface_future(self.network.weights + self.network.biases, config=self.config)
        # self.l_opt_interface = gradient_interface_future([self.layer.W_l], config=self.config)

    def update_parameters(self, update_wl):
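        """
        Apply one optimization step from the accumulated batch gradients:
        averaged backprop gradients for the network weights and biases, and,
        when update_wl is True, the averaged REINFORCE gradient for W_l.
        """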
        if not self.disable_backprop:
            grads = [self.batch_grad[i] / self.batch_size for i in range(len(self.network.weights + self.network.biases))]
            self.opt_interface(*grads)
        # REINFORCE update
        if update_wl and not self.disable_reinforce:
            if np.sum(self.batch_wl_grad) == 0:
                sys.stdout.write("[0 WLG] ")
                sys.stdout.flush()
            else:
                grad_wl = self.batch_wl_grad / self.batch_size
                self.l_opt_interface(grad_wl)

    def train_func(self, train_set):
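        """
        Run one epoch over train_set: evaluate the cost and gradients for each
        sample, accumulate them over batch_size samples, then update the
        parameters and reset the batch accumulators.
        """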
        cost_sum = 0.0
        batch_cost = 0.0
        counter = 0
        total = 0
        total_reward = 0
        batch_reward = 0
        total_position_value = 0
        pena_count = 0
        for d in train_set:
            pairs = self.grad_func(*d)
            cost = pairs[0]
            if cost > 10 or np.isnan(cost):
                sys.stdout.write("X")
                sys.stdout.flush()
                continue
            batch_cost += cost

            wl_grad = pairs[1]
            max_position_value = np.max(np.absolute(pairs[2]))
            total_position_value += max_position_value
            last_decision = pairs[3]
            target_decision = d[1][0]
            reward = 0.005 if last_decision == target_decision else 0
            if max_position_value > 0.8:
                reward = 0
            total_reward += reward
            batch_reward += reward
            if self.last_average_reward == 999 and total > 2000:
                self.last_average_reward = total_reward / total
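            # REINFORCE: accumulate the W_l gradient weighted by the negative
            # advantage, i.e. -(reward - running average reward baseline).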
            if not self.disable_reinforce:
                self.batch_wl_grad += wl_grad * -(reward - self.last_average_reward)
            if not self.disable_backprop:
                for grad_cache, grad in zip(self.batch_grad, pairs[4:]):
                    grad_cache += grad
            counter += 1
            total += 1
            if counter >= self.batch_size:
                if total == counter:
                    counter -= 1
                self.update_parameters(self.last_average_reward < 999)

                # Clean batch gradients
                if not self.disable_reinforce:
                    self.batch_wl_grad *= 0
                if not self.disable_backprop:
                    for grad_cache in self.batch_grad:
                        grad_cache *= 0

                if total % 1000 == 0:
                    sys.stdout.write(".")
                    sys.stdout.flush()

                # Covariance switching
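                # While in small-covariance mode, more than 20 consecutive
                # low-reward batches (average reward < 0.001) switch the position
                # sampler to the large covariance; in large-covariance mode, 20+
                # consecutive recovered batches switch it back. The cached inverse
                # and determinant are refreshed after each switch.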
                if not self.disable_reinforce:
                    cov_changed = False
                    if batch_reward / self.batch_size < 0.001:
                        if not self.large_cov_mode:
                            if pena_count > 20:
                                self.layer.cov.set_value(self.layer.large_cov)
                                print "[LCOV]",
                                cov_changed = True
                            else:
                                pena_count += 1
                        else:
                            pena_count = 0
                    else:
                        if self.large_cov_mode:
                            if pena_count > 20:
                                self.layer.cov.set_value(self.layer.small_cov)
                                print "[SCOV]",
                                cov_changed = True
                            else:
                                pena_count += 1
                        else:
                            pena_count = 0
                    if cov_changed:
                        self.large_cov_mode = not self.large_cov_mode
                        self.layer.cov_inv_var.set_value(np.array(LA.inv(self.layer.cov.get_value()), dtype=FLOATX))
                        self.layer.cov_det_var.set_value(LA.det(self.layer.cov.get_value()))

                # Clean batch cost
                counter = 0
                cost_sum += batch_cost
                batch_cost = 0.0
                batch_reward = 0
        if total == 0:
            return "COST OVERFLOW"

        sys.stdout.write("\n")
        self.last_average_reward = (total_reward / total)
        self.turn += 1
        return "J: %.2f, Avg R: %.4f, Avg P: %.2f" % ((cost_sum / total), self.last_average_reward, (total_position_value / total))