#!/usr/bin/env python
# -*- coding: utf-8 -*-

from . import NeuralLayer
from .var import NeuralVariable
from deepy.utils import build_activation, FLOATX
import numpy as np
import theano
import theano.tensor as T
from collections import OrderedDict

OUTPUT_TYPES = ["sequence", "one"]
INPUT_TYPES = ["sequence", "one"]

class LSTM(NeuralLayer):
    """
    Long short-term memory layer.
    """

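    # Constructor arguments, summarised from the code below:
    #   hidden_size             -- number of hidden units / cell memory size
    #   input_type, output_type -- "sequence" or "one"
    #   persistent_state        -- carry h/c across batches (requires batch_size > 0)
    #   reset_state_for_input   -- index of an input column that, when equal to 1,
    #                              resets the recurrent state for that sample
    #   forget_bias             -- initial bias of the forget gate
    #   mask                    -- optional (batch, time) mask for padded sequences
    #   second_input            -- optional extra input added to the gate pre-activations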
    def __init__(self, hidden_size, input_type="sequence", output_type="sequence",
                 inner_activation="sigmoid", outer_activation="tanh",
                 inner_init=None, outer_init=None, steps=None,
                 go_backwards=False,
                 persistent_state=False, batch_size=0,
                 reset_state_for_input=None, forget_bias=1,
                 mask=None,
                 second_input=None, second_input_size=None):
        super(LSTM, self).__init__("lstm")
        self._hidden_size = hidden_size
        self._input_type = input_type
        self._output_type = output_type
        self._inner_activation = inner_activation
        self._outer_activation = outer_activation
        self._inner_init = inner_init
        self._outer_init = outer_init
        self._steps = steps
        self.persistent_state = persistent_state
        self.reset_state_for_input = reset_state_for_input
        self.batch_size = batch_size
        self.go_backwards = go_backwards
        # Mask: accept either a NeuralVariable or a raw tensor, and store it time-major
        mask = mask.tensor if isinstance(mask, NeuralVariable) else mask
        self.mask = mask.dimshuffle((1, 0)) if mask is not None else None
        self._sequence_map = OrderedDict()
        # Second input
        if isinstance(second_input, NeuralVariable):
            second_input_size = second_input.dim()
            second_input = second_input.tensor

        self.second_input = second_input
        self.second_input_size = second_input_size
        self.forget_bias = forget_bias
        if input_type not in INPUT_TYPES:
            raise Exception("Input type of LSTM is wrong: %s" % input_type)
        if output_type not in OUTPUT_TYPES:
            raise Exception("Output type of LSTM is wrong: %s" % output_type)
        if self.persistent_state and not self.batch_size:
            raise Exception("Batch size must be set for persistent state mode")
        if mask is not None and input_type == "one":
            raise Exception("Mask only works with sequence input")

    def _auto_reset_memories(self, x, h, m):
        # Zero the hidden state and cell memory of samples whose input column
        # `reset_state_for_input` equals 1 (e.g. a sequence-start flag).
        reset_matrix = T.neq(x[:, self.reset_state_for_input], 1).dimshuffle(0, 'x')
        h = h * reset_matrix
        m = m * reset_matrix
        return h, m

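    # One step of the standard LSTM recurrence, as implemented in step() below
    # (the x*_t terms are the precomputed input projections, see produce_input_sequences):
    #   i_t = inner_act(xi_t + U_i h_{t-1} + b_i)
    #   f_t = inner_act(xf_t + U_f h_{t-1} + b_f)
    #   c_t = f_t * c_{t-1} + i_t * outer_act(xc_t + U_c h_{t-1} + b_c)
    #   o_t = inner_act(xo_t + U_o h_{t-1} + b_o)
    #   h_t = o_t * outer_act(c_t)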
    def step(self, *vars):
        # Parse the sequence inputs passed in by theano.scan
        sequence_map = dict(zip(self._sequence_map.keys(), vars[:len(self._sequence_map)]))
        h_tm1, c_tm1 = vars[-2:]
        # Reset state
        if self.reset_state_for_input is not None:
            h_tm1, c_tm1 = self._auto_reset_memories(sequence_map["x"], h_tm1, c_tm1)

        if self._input_type == "sequence":
            xi_t, xf_t, xo_t, xc_t = map(sequence_map.get, ["xi", "xf", "xo", "xc"])
        else:
            xi_t, xf_t, xo_t, xc_t = 0, 0, 0, 0

        # Add second input
        if "xi2" in sequence_map:
            xi2, xf2, xo2, xc2 = map(sequence_map.get, ["xi2", "xf2", "xo2", "xc2"])
            xi_t += xi2
            xf_t += xf2
            xo_t += xo2
            xc_t += xc2
        # LSTM core step
        i_t = self._inner_act(xi_t + T.dot(h_tm1, self.U_i) + self.b_i)
        f_t = self._inner_act(xf_t + T.dot(h_tm1, self.U_f) + self.b_f)
        c_t = f_t * c_tm1 + i_t * self._outer_act(xc_t + T.dot(h_tm1, self.U_c) + self.b_c)
        o_t = self._inner_act(xo_t + T.dot(h_tm1, self.U_o) + self.b_o)
        h_t = o_t * self._outer_act(c_t)
        # Apply mask: keep the previous state for padded positions
        if "mask" in sequence_map:
            mask = sequence_map["mask"].dimshuffle(0, 'x')
            h_t = h_t * mask + h_tm1 * (1 - mask)
            c_t = c_t * mask + c_tm1 * (1 - mask)
        return h_t, c_t

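    # The input-to-gate projections (x W_i, x W_f, x W_c, x W_o) do not depend on the
    # recurrent state, so they are computed once over the whole sequence here and fed
    # to theano.scan as `sequences`, together with the optional mask and second input.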
    def produce_input_sequences(self, x, mask=None, second_input=None):
        # Create sequence map
        self._sequence_map.clear()
        if self._input_type == "sequence":
            # Input vars
            xi = T.dot(x, self.W_i)
            xf = T.dot(x, self.W_f)
            xc = T.dot(x, self.W_c)
            xo = T.dot(x, self.W_o)
            self._sequence_map.update([("xi", xi), ("xf", xf), ("xc", xc), ("xo", xo)])
            # Reset state
            if self.reset_state_for_input is not None:
                self._sequence_map["x"] = x
        # Add mask
        if mask is not None:
            self._sequence_map["mask"] = mask
        elif self.mask is not None:
            self._sequence_map["mask"] = self.mask
        # Add second input
        if self.second_input is not None and second_input is None:
            second_input = self.second_input
        if second_input is not None:
            xi2 = T.dot(second_input, self.W_i2)
            xf2 = T.dot(second_input, self.W_f2)
            xc2 = T.dot(second_input, self.W_c2)
            xo2 = T.dot(second_input, self.W_o2)
            self._sequence_map.update([("xi2", xi2), ("xf2", xf2), ("xc2", xc2), ("xo2", xo2)])
        return list(self._sequence_map.values())

    def produce_initial_states(self, x):
        if self.persistent_state:
            return self.state_h, self.state_m
        else:
            h0 = T.alloc(np.cast[FLOATX](0.), x.shape[0], self._hidden_size)
            m0 = h0
            return h0, m0

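    # compute_tensor() builds the symbolic forward pass: inputs are shuffled to a
    # time-major layout, theano.scan applies step() along the time axis while carrying
    # (h, c) through outputs_info, and the result is shuffled back to batch-major when
    # a full sequence is requested.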
    def compute_tensor(self, x):
        h0, m0 = self.produce_initial_states(x)
        if self._input_type == "sequence":
            # Move the time dimension to the left-most position:
            # (batch, time, value) -> (time, batch, value)
            x = x.dimshuffle((1, 0, 2))
            sequences = self.produce_input_sequences(x)
        else:
            h0 = x
            sequences = self.produce_input_sequences(None)

        [hiddens, memories], _ = theano.scan(
            self.step,
            sequences=sequences,
            outputs_info=[h0, m0],
            n_steps=self._steps,
            go_backwards=self.go_backwards
        )

        # Save persistent state
        if self.persistent_state:
            self.register_updates((self.state_h, hiddens[-1]))
            self.register_updates((self.state_m, memories[-1]))

        if self._output_type == "one":
            return hiddens[-1]
        elif self._output_type == "sequence":
            return hiddens.dimshuffle((1, 0, 2))

    def prepare(self):
        self._setup_params()
        self._setup_functions()

    def _setup_functions(self):
        self._inner_act = build_activation(self._inner_activation)
        self._outer_act = build_activation(self._outer_activation)

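    # Parameter layout: each gate g in {i, f, c, o} has an input projection W_g of
    # shape (input_dim, hidden_size), a recurrent matrix U_g of shape
    # (hidden_size, hidden_size) and a bias b_g of length hidden_size; the optional
    # second input gets its own projections W_g2.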
    def _setup_params(self):
        self.output_dim = self._hidden_size

        self.W_i = self.create_weight(self.input_dim, self._hidden_size, "wi", initializer=self._outer_init)
        self.U_i = self.create_weight(self._hidden_size, self._hidden_size, "ui", initializer=self._inner_init)
        self.b_i = self.create_bias(self._hidden_size, "i")

        self.W_f = self.create_weight(self.input_dim, self._hidden_size, "wf", initializer=self._outer_init)
        self.U_f = self.create_weight(self._hidden_size, self._hidden_size, "uf", initializer=self._inner_init)
        self.b_f = self.create_bias(self._hidden_size, "f")
        if self.forget_bias > 0:
            # Start with an open forget gate: initialise b_f to the configured bias value
            self.b_f.set_value(np.full((self._hidden_size,), self.forget_bias, dtype=FLOATX))

        self.W_c = self.create_weight(self.input_dim, self._hidden_size, "wc", initializer=self._outer_init)
        self.U_c = self.create_weight(self._hidden_size, self._hidden_size, "uc", initializer=self._inner_init)
        self.b_c = self.create_bias(self._hidden_size, "c")

        self.W_o = self.create_weight(self.input_dim, self._hidden_size, "wo", initializer=self._outer_init)
        self.U_o = self.create_weight(self._hidden_size, self._hidden_size, "uo", initializer=self._inner_init)
        self.b_o = self.create_bias(self._hidden_size, "o")

        if self._input_type == "sequence":
            self.register_parameters(self.W_i, self.U_i, self.b_i,
                                     self.W_c, self.U_c, self.b_c,
                                     self.W_f, self.U_f, self.b_f,
                                     self.W_o, self.U_o, self.b_o)
        else:
            self.register_parameters(self.U_i, self.b_i,
                                     self.U_c, self.b_c,
                                     self.U_f, self.b_f,
                                     self.U_o, self.b_o)
        # Second input
        if self.second_input_size:
            self.W_i2 = self.create_weight(self.second_input_size, self._hidden_size, "wi2", initializer=self._outer_init)
            self.W_f2 = self.create_weight(self.second_input_size, self._hidden_size, "wf2", initializer=self._outer_init)
            self.W_c2 = self.create_weight(self.second_input_size, self._hidden_size, "wc2", initializer=self._outer_init)
            self.W_o2 = self.create_weight(self.second_input_size, self._hidden_size, "wo2", initializer=self._outer_init)
            self.register_parameters(self.W_i2, self.W_f2, self.W_c2, self.W_o2)

        # Create persistent state
        if self.persistent_state:
            self.state_h = self.create_matrix(self.batch_size, self._hidden_size, "lstm_state_h")
            self.state_m = self.create_matrix(self.batch_size, self._hidden_size, "lstm_state_m")
            self.register_free_parameters(self.state_h, self.state_m)
        else:
            self.state_h = None
            self.state_m = None
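

# Illustrative sketch (not part of the layer API): one LSTM step in plain numpy,
# mirroring the gate equations in LSTM.step() above. The dimensions, random weights
# and the _sigmoid helper below are made up for the example; the snippet itself only
# needs numpy, so it can also be copied into a standalone script.
if __name__ == "__main__":
    def _sigmoid(z):
        return 1.0 / (1.0 + np.exp(-z))

    rng = np.random.RandomState(3)
    input_dim, hidden = 5, 4
    x = rng.randn(2, input_dim)       # a batch of 2 input vectors
    h_tm1 = np.zeros((2, hidden))     # previous hidden state
    c_tm1 = np.zeros((2, hidden))     # previous cell memory
    W = dict((g, rng.randn(input_dim, hidden) * 0.1) for g in "ifco")
    U = dict((g, rng.randn(hidden, hidden) * 0.1) for g in "ifco")
    b = dict((g, np.zeros(hidden)) for g in "ifco")

    i_t = _sigmoid(x.dot(W["i"]) + h_tm1.dot(U["i"]) + b["i"])
    f_t = _sigmoid(x.dot(W["f"]) + h_tm1.dot(U["f"]) + b["f"])
    c_t = f_t * c_tm1 + i_t * np.tanh(x.dot(W["c"]) + h_tm1.dot(U["c"]) + b["c"])
    o_t = _sigmoid(x.dot(W["o"]) + h_tm1.dot(U["o"]) + b["o"])
    h_t = o_t * np.tanh(c_t)
    print(h_t.shape)  # (2, 4): one hidden vector per sample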