#!/usr/bin/env python
# -*- coding: utf-8 -*-

from . import NeuralLayer
from neural_var import NeuralVariable
from deepy.utils import build_activation, FLOATX
import numpy as np
import theano
import theano.tensor as T
from collections import OrderedDict

OUTPUT_TYPES = ["sequence", "one"]
INPUT_TYPES = ["sequence", "one"]

class LSTM(NeuralLayer):
    """
    Long short-term memory layer.
    """

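    # Constructor options, as used below: input_type/output_type select between
    # "sequence" and "one"; persistent_state keeps the hidden/memory states in
    # shared variables across calls (requires batch_size); reset_state_for_input
    # zeroes the state for samples whose input at that column is not 1;
    # forget_bias sets the initial forget-gate bias; second_input feeds an extra
    # conditioning input into all four gates; go_backwards reverses the scan.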
    def __init__(self, hidden_size, input_type="sequence", output_type="sequence",
                 inner_activation="sigmoid", outer_activation="tanh",
                 inner_init=None, outer_init=None, steps=None,
                 go_backwards=False,
                 persistent_state=False, batch_size=0,
                 reset_state_for_input=None, forget_bias=1,
                 mask=None,
                 second_input=None, second_input_size=None):
        super(LSTM, self).__init__("lstm")
        self.main_state = 'state'
        self._hidden_size = hidden_size
        self._input_type = input_type
        self._output_type = output_type
        self._inner_activation = inner_activation
        self._outer_activation = outer_activation
        self._inner_init = inner_init
        self._outer_init = outer_init
        self._steps = steps
        self.persistent_state = persistent_state
        self.reset_state_for_input = reset_state_for_input
        self.batch_size = batch_size
        self.go_backwards = go_backwards
        # mask
        mask = mask.tensor if type(mask) == NeuralVariable else mask
        self.mask = mask.dimshuffle((1,0)) if mask is not None else None
        self._sequence_map = OrderedDict()
        # second input
        if type(second_input) == NeuralVariable:
            second_input_size = second_input.dim()
            second_input = second_input.tensor

        self.second_input = second_input
        self.second_input_size = second_input_size
        self.forget_bias = forget_bias
        if input_type not in INPUT_TYPES:
            raise Exception("Input type of LSTM is wrong: %s" % input_type)
        if output_type not in OUTPUT_TYPES:
            raise Exception("Output type of LSTM is wrong: %s" % output_type)
        if self.persistent_state and not self.batch_size:
            raise Exception("Batch size must be set for persistent state mode")
        if mask is not None and input_type == "one":
            raise Exception("Mask only works with sequence input")

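    # Zero out the hidden state and memory cell for samples whose input value at
    # column `reset_state_for_input` is not 1, so those sequences restart from a
    # blank state (called from the scan step when reset_state_for_input is set).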
    def _auto_reset_memories(self, x, h, m):
        reset_matrix = T.neq(x[:, self.reset_state_for_input], 1).dimshuffle(0, 'x')
        h = h * reset_matrix
        m = m * reset_matrix
        return h, m

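    # One scan step. `vars` packs the per-timestep sequences registered in
    # self._sequence_map (in order), followed by the recurrent outputs of the
    # previous step: the hidden state h_tm1 and the memory cell c_tm1.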
    def step(self, *vars):
        # Parse sequence
        sequence_map = dict(zip(self._sequence_map.keys(), vars[:len(self._sequence_map)]))
        h_tm1, c_tm1 = vars[-2:]
        # Reset state
        if self.reset_state_for_input is not None:
            h_tm1, c_tm1 = self._auto_reset_memories(sequence_map["x"], h_tm1, c_tm1)

        if self._input_type == "sequence":
            xi_t, xf_t, xo_t, xc_t = map(sequence_map.get, ["xi", "xf", "xo", "xc"])
        else:
            xi_t, xf_t, xo_t, xc_t = 0, 0, 0, 0

        # Add second input
        if "xi2" in sequence_map:
            xi2, xf2, xo2, xc2 = map(sequence_map.get, ["xi2", "xf2", "xo2", "xc2"])
            xi_t += xi2
            xf_t += xf2
            xo_t += xo2
            xc_t += xc2
        # LSTM core step
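        # Gate equations implemented below (inner_act defaults to sigmoid,
        # outer_act to tanh):
        #   i_t = inner_act(x_i + h_tm1 . U_i + b_i)    input gate
        #   f_t = inner_act(x_f + h_tm1 . U_f + b_f)    forget gate
        #   c_t = f_t * c_tm1 + i_t * outer_act(x_c + h_tm1 . U_c + b_c)
        #   o_t = inner_act(x_o + h_tm1 . U_o + b_o)    output gate
        #   h_t = o_t * outer_act(c_t)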
        i_t = self._inner_act(xi_t + T.dot(h_tm1, self.U_i) + self.b_i)
        f_t = self._inner_act(xf_t + T.dot(h_tm1, self.U_f) + self.b_f)
        c_t = f_t * c_tm1 + i_t * self._outer_act(xc_t + T.dot(h_tm1, self.U_c) + self.b_c)
        o_t = self._inner_act(xo_t + T.dot(h_tm1, self.U_o) + self.b_o)
        h_t = o_t * self._outer_act(c_t)
        # Apply mask
        if "mask" in sequence_map:
            mask = sequence_map["mask"].dimshuffle(0, 'x')
            h_t = h_t * mask + h_tm1 * (1 - mask)
            c_t = c_t * mask + c_tm1 * (1 - mask)
        return h_t, c_t

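    # Precompute the input-to-hidden projections (x . W_*) for the whole sequence
    # outside of scan and register them as per-timestep sequences, so the scan
    # step only has to add the hidden-to-hidden terms.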
    def get_step_inputs(self, x, mask=None, second_input=None):
        # Create sequence map
        self._sequence_map.clear()
        if self._input_type == "sequence":
            # Input vars
            xi = T.dot(x, self.W_i)
            xf = T.dot(x, self.W_f)
            xc = T.dot(x, self.W_c)
            xo = T.dot(x, self.W_o)
            self._sequence_map.update([("xi", xi), ("xf", xf), ("xc", xc), ("xo", xo)])
        # Reset state
        if self.reset_state_for_input is not None:
            self._sequence_map["x"] = x
        # Add mask
        if mask is not None:
            self._sequence_map["mask"] = mask
        elif self.mask is not None:
            self._sequence_map["mask"] = self.mask
        # Add second input
        if self.second_input is not None and second_input is None:
            second_input = self.second_input
        if second_input is not None:
            xi2 = T.dot(second_input, self.W_i2)
            xf2 = T.dot(second_input, self.W_f2)
            xc2 = T.dot(second_input, self.W_c2)
            xo2 = T.dot(second_input, self.W_o2)
            self._sequence_map.update([("xi2", xi2), ("xf2", xf2), ("xc2", xc2), ("xo2", xo2)])
        return list(self._sequence_map.values())

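    # Initial recurrent states: either the persistent shared variables (carried
    # across calls) or freshly allocated zero matrices of shape
    # (batch_size, hidden_size).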
    def get_initial_states(self, x):
        if self.persistent_state:
            return self.state_h, self.state_m
        else:
            h0 = T.alloc(np.cast[FLOATX](0.), x.shape[0], self._hidden_size)
            m0 = h0
            return h0, m0

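    # Forward computation: shuffle the input to time-major layout, run the step
    # function through theano.scan over the sequence, then either return the last
    # hidden state ("one") or the full sequence shuffled back to batch-major
    # ("sequence").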
    def compute_tensor(self, x):
        h0, m0 = self.get_initial_states(x)
        if self._input_type == "sequence":
            # Move middle dimension to left-most position
            # (sequence, batch, value)
            x = x.dimshuffle((1,0,2))
            sequences = self.get_step_inputs(x)
        else:
            h0 = x
            sequences = self.get_step_inputs(None)

        [hiddens, memories], _ = theano.scan(
            self.step,
            sequences=sequences,
            outputs_info=[h0, m0],
            n_steps=self._steps,
            go_backwards=self.go_backwards
        )

        # Save persistent state
        if self.persistent_state:
            self.register_updates((self.state_h, hiddens[-1]))
            self.register_updates((self.state_m, memories[-1]))

        if self._output_type == "one":
            return hiddens[-1]
        elif self._output_type == "sequence":
            return hiddens.dimshuffle((1,0,2))


    def prepare(self):
        self._setup_params()
        self._setup_functions()

    def _setup_functions(self):
        self._inner_act = build_activation(self._inner_activation)
        self._outer_act = build_activation(self._outer_activation)

    def _setup_params(self):
        self.output_dim = self._hidden_size

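        # Parameter layout: W_* project the input (input_dim x hidden_size),
        # U_* are the recurrent weights (hidden_size x hidden_size), and b_* are
        # the gate biases, for the input (i), forget (f), cell (c) and output (o) gates.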
        self.W_i = self.create_weight(self.input_dim, self._hidden_size, "wi", initializer=self._outer_init)
        self.U_i = self.create_weight(self._hidden_size, self._hidden_size, "ui", initializer=self._inner_init)
        self.b_i = self.create_bias(self._hidden_size, "i")

        self.W_f = self.create_weight(self.input_dim, self._hidden_size, "wf", initializer=self._outer_init)
        self.U_f = self.create_weight(self._hidden_size, self._hidden_size, "uf", initializer=self._inner_init)
        self.b_f = self.create_bias(self._hidden_size, "f")
        if self.forget_bias > 0:
            self.b_f.set_value(np.ones((self._hidden_size,), dtype=FLOATX) * self.forget_bias)

        self.W_c = self.create_weight(self.input_dim, self._hidden_size, "wc", initializer=self._outer_init)
        self.U_c = self.create_weight(self._hidden_size, self._hidden_size, "uc", initializer=self._inner_init)
        self.b_c = self.create_bias(self._hidden_size, "c")

        self.W_o = self.create_weight(self.input_dim, self._hidden_size, "wo", initializer=self._outer_init)
        self.U_o = self.create_weight(self._hidden_size, self._hidden_size, "uo", initializer=self._inner_init)
        self.b_o = self.create_bias(self._hidden_size, suffix="o")


        if self._input_type == "sequence":
            self.register_parameters(self.W_i, self.U_i, self.b_i,
                                     self.W_c, self.U_c, self.b_c,
                                     self.W_f, self.U_f, self.b_f,
                                     self.W_o, self.U_o, self.b_o)
        else:
            self.register_parameters(self.U_i, self.b_i,
                                     self.U_c, self.b_c,
                                     self.U_f, self.b_f,
                                     self.U_o, self.b_o)
        # Second input
        if self.second_input_size:
            self.W_i2 = self.create_weight(self.second_input_size, self._hidden_size, "wi2", initializer=self._outer_init)
            self.W_f2 = self.create_weight(self.second_input_size, self._hidden_size, "wf2", initializer=self._outer_init)
            self.W_c2 = self.create_weight(self.second_input_size, self._hidden_size, "wc2", initializer=self._outer_init)
            self.W_o2 = self.create_weight(self.second_input_size, self._hidden_size, "wo2", initializer=self._outer_init)
            self.register_parameters(self.W_i2, self.W_f2, self.W_c2, self.W_o2)

        # Create persistent state
        if self.persistent_state:
            self.state_h = self.create_matrix(self.batch_size, self._hidden_size, "lstm_state_h")
            self.state_m = self.create_matrix(self.batch_size, self._hidden_size, "lstm_state_m")
            self.register_free_parameters(self.state_h, self.state_m)
        else:
            self.state_h = None
            self.state_m = None