#!/usr/bin/env python
# -*- coding: utf-8 -*-

from . import NeuralLayer
from .var import NeuralVariable
from deepy.utils import build_activation, FLOATX
import numpy as np
import theano
import theano.tensor as T
from collections import OrderedDict

OUTPUT_TYPES = ["sequence", "one"]
INPUT_TYPES = ["sequence", "one"]


class LSTM(NeuralLayer):
    """
    Long short-term memory layer.

    With input_type="sequence", the layer consumes a (batch, time, dim)
    tensor; with input_type="one", the input is used as the initial hidden
    state instead. output_type selects between returning the full hidden
    sequence ("sequence") and only the final hidden state ("one").
    """

    def __init__(self, hidden_size, input_type="sequence", output_type="sequence",
                 inner_activation="sigmoid", outer_activation="tanh",
                 inner_init=None, outer_init=None, steps=None,
                 go_backwards=False,
                 persistent_state=False, batch_size=0,
                 reset_state_for_input=None, forget_bias=1,
                 mask=None,
                 second_input=None, second_input_size=None):
        super(LSTM, self).__init__("lstm")
        self._hidden_size = hidden_size
        self._input_type = input_type
        self._output_type = output_type
        self._inner_activation = inner_activation
        self._outer_activation = outer_activation
        self._inner_init = inner_init
        self._outer_init = outer_init
        self._steps = steps
        self.persistent_state = persistent_state
        self.reset_state_for_input = reset_state_for_input
        self.batch_size = batch_size
        self.go_backwards = go_backwards
        # Mask: move the time axis to the front, matching the scan layout.
        # Note: Theano tensors do not support boolean checks, so compare
        # against None explicitly.
        mask = mask.tensor if isinstance(mask, NeuralVariable) else mask
        self.mask = mask.dimshuffle((1, 0)) if mask is not None else None
        self._sequence_map = OrderedDict()
        # Second input
        if isinstance(second_input, NeuralVariable):
            second_input_size = second_input.dim()
            second_input = second_input.tensor

        self.second_input = second_input
        self.second_input_size = second_input_size
        self.forget_bias = forget_bias
        if input_type not in INPUT_TYPES:
            raise ValueError("Invalid input type for LSTM: %s" % input_type)
        if output_type not in OUTPUT_TYPES:
            raise ValueError("Invalid output type for LSTM: %s" % output_type)
        if self.persistent_state and not self.batch_size:
            raise ValueError("Batch size must be set for persistent state mode")
        if mask is not None and input_type == "one":
            raise ValueError("Mask only works with sequence input")

    def _auto_reset_memories(self, x, h, m):
        # Zero the hidden and memory states of a sample whenever the
        # designated input unit fires (equals 1) at the current timestep.
        reset_matrix = T.neq(x[:, self.reset_state_for_input], 1).dimshuffle(0, 'x')
        h = h * reset_matrix
        m = m * reset_matrix
        return h, m

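    # For reference, step() below implements the standard LSTM recurrence
    # (sigma = inner activation, g = outer activation):
    #
    #   i_t = sigma(W_i x_t + U_i h_{t-1} + b_i)
    #   f_t = sigma(W_f x_t + U_f h_{t-1} + b_f)
    #   c_t = f_t * c_{t-1} + i_t * g(W_c x_t + U_c h_{t-1} + b_c)
    #   o_t = sigma(W_o x_t + U_o h_{t-1} + b_o)
    #   h_t = o_t * g(c_t)
    #
    # The W_* x_t terms arrive precomputed through the sequence map (see
    # produce_input_sequences); step() only adds the recurrent U_* terms.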
    def step(self, *args):
        # Parse the scanned sequences; the last two arguments are the
        # recurrent outputs h_{t-1} and c_{t-1}
        sequence_map = dict(zip(self._sequence_map.keys(), args[:len(self._sequence_map)]))
        h_tm1, c_tm1 = args[-2:]
        # Reset state
        if self.reset_state_for_input is not None:
            h_tm1, c_tm1 = self._auto_reset_memories(sequence_map["x"], h_tm1, c_tm1)

        if self._input_type == "sequence":
            xi_t, xf_t, xo_t, xc_t = map(sequence_map.get, ["xi", "xf", "xo", "xc"])
        else:
            xi_t, xf_t, xo_t, xc_t = 0, 0, 0, 0

        # Add second input
        if "xi2" in sequence_map:
            xi2, xf2, xo2, xc2 = map(sequence_map.get, ["xi2", "xf2", "xo2", "xc2"])
            xi_t += xi2
            xf_t += xf2
            xo_t += xo2
            xc_t += xc2
        # LSTM core step
        i_t = self._inner_act(xi_t + T.dot(h_tm1, self.U_i) + self.b_i)
        f_t = self._inner_act(xf_t + T.dot(h_tm1, self.U_f) + self.b_f)
        c_t = f_t * c_tm1 + i_t * self._outer_act(xc_t + T.dot(h_tm1, self.U_c) + self.b_c)
        o_t = self._inner_act(xo_t + T.dot(h_tm1, self.U_o) + self.b_o)
        h_t = o_t * self._outer_act(c_t)
        # Apply mask: keep the previous states for padded timesteps
        if "mask" in sequence_map:
            mask = sequence_map["mask"].dimshuffle(0, 'x')
            h_t = h_t * mask + h_tm1 * (1 - mask)
            c_t = c_t * mask + c_tm1 * (1 - mask)
        return h_t, c_t

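    # Design note: the input-to-hidden products W_* x are computed below for
    # all timesteps in a single batched matrix multiplication, outside of
    # theano.scan; only the hidden-to-hidden products remain inside step(),
    # which keeps the per-timestep graph small.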
    def produce_input_sequences(self, x, mask=None, second_input=None):
        # Build the ordered map of sequences that scan will iterate over
        self._sequence_map.clear()
        if self._input_type == "sequence":
            # Input-to-hidden projections, precomputed for all timesteps
            xi = T.dot(x, self.W_i)
            xf = T.dot(x, self.W_f)
            xc = T.dot(x, self.W_c)
            xo = T.dot(x, self.W_o)
            self._sequence_map.update([("xi", xi), ("xf", xf), ("xc", xc), ("xo", xo)])
            # Raw input, needed by the state-reset logic
            if self.reset_state_for_input is not None:
                self._sequence_map["x"] = x
        # Add mask
        if mask is not None:
            self._sequence_map["mask"] = mask
        elif self.mask is not None:
            self._sequence_map["mask"] = self.mask
        # Add second input
        if self.second_input is not None and second_input is None:
            second_input = self.second_input
        if second_input is not None:
            xi2 = T.dot(second_input, self.W_i2)
            xf2 = T.dot(second_input, self.W_f2)
            xc2 = T.dot(second_input, self.W_c2)
            xo2 = T.dot(second_input, self.W_o2)
            self._sequence_map.update([("xi2", xi2), ("xf2", xf2), ("xc2", xc2), ("xo2", xo2)])
        return list(self._sequence_map.values())

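    # When persistent_state is on, the initial states come from shared
    # matrices that are updated with the final states of the previous batch
    # (see compute_tensor), so state is carried across consecutive batches;
    # this is why batch_size must be fixed in advance.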
    def produce_initial_states(self, x):
        if self.persistent_state:
            return self.state_h, self.state_m
        else:
            h0 = T.alloc(np.cast[FLOATX](0.), x.shape[0], self._hidden_size)
            m0 = h0
            return h0, m0

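    # Shape convention: inputs arrive as (batch, time, dim); scan iterates
    # over the leading axis, so the tensor is shuffled to (time, batch, dim)
    # before scanning and the hidden sequence is shuffled back afterwards.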
    def compute_tensor(self, x):
        h0, m0 = self.produce_initial_states(x)
        if self._input_type == "sequence":
            # Move the time axis to the front: (batch, time, dim) -> (time, batch, dim)
            x = x.dimshuffle((1, 0, 2))
            sequences = self.produce_input_sequences(x)
        else:
            # A single input serves as the initial hidden state
            h0 = x
            sequences = self.produce_input_sequences(None)

        [hiddens, memories], _ = theano.scan(
            self.step,
            sequences=sequences,
            outputs_info=[h0, m0],
            n_steps=self._steps,
            go_backwards=self.go_backwards
        )

        # Save persistent state
        if self.persistent_state:
            self.register_updates((self.state_h, hiddens[-1]))
            self.register_updates((self.state_m, memories[-1]))

        if self._output_type == "one":
            return hiddens[-1]
        elif self._output_type == "sequence":
            # Shuffle back to (batch, time, dim)
            return hiddens.dimshuffle((1, 0, 2))

    def prepare(self):
        self._setup_params()
        self._setup_functions()

    def _setup_functions(self):
        self._inner_act = build_activation(self._inner_activation)
        self._outer_act = build_activation(self._outer_activation)

    def _setup_params(self):
        self.output_dim = self._hidden_size

        self.W_i = self.create_weight(self.input_dim, self._hidden_size, "wi", initializer=self._outer_init)
        self.U_i = self.create_weight(self._hidden_size, self._hidden_size, "ui", initializer=self._inner_init)
        self.b_i = self.create_bias(self._hidden_size, "i")

        self.W_f = self.create_weight(self.input_dim, self._hidden_size, "wf", initializer=self._outer_init)
        self.U_f = self.create_weight(self._hidden_size, self._hidden_size, "uf", initializer=self._inner_init)
        self.b_f = self.create_bias(self._hidden_size, "f")
        if self.forget_bias > 0:
            # A positive forget-gate bias encourages remembering early in
            # training; scale by the configured value rather than fixing it to 1
            self.b_f.set_value(np.full((self._hidden_size,), self.forget_bias, dtype=FLOATX))

        self.W_c = self.create_weight(self.input_dim, self._hidden_size, "wc", initializer=self._outer_init)
        self.U_c = self.create_weight(self._hidden_size, self._hidden_size, "uc", initializer=self._inner_init)
        self.b_c = self.create_bias(self._hidden_size, "c")

        self.W_o = self.create_weight(self.input_dim, self._hidden_size, "wo", initializer=self._outer_init)
        self.U_o = self.create_weight(self._hidden_size, self._hidden_size, "uo", initializer=self._inner_init)
        self.b_o = self.create_bias(self._hidden_size, "o")

        if self._input_type == "sequence":
            self.register_parameters(self.W_i, self.U_i, self.b_i,
                                     self.W_c, self.U_c, self.b_c,
                                     self.W_f, self.U_f, self.b_f,
                                     self.W_o, self.U_o, self.b_o)
        else:
            self.register_parameters(self.U_i, self.b_i,
                                     self.U_c, self.b_c,
                                     self.U_f, self.b_f,
                                     self.U_o, self.b_o)
        # Second input
        if self.second_input_size:
            self.W_i2 = self.create_weight(self.second_input_size, self._hidden_size, "wi2", initializer=self._outer_init)
            self.W_f2 = self.create_weight(self.second_input_size, self._hidden_size, "wf2", initializer=self._outer_init)
            self.W_c2 = self.create_weight(self.second_input_size, self._hidden_size, "wc2", initializer=self._outer_init)
            self.W_o2 = self.create_weight(self.second_input_size, self._hidden_size, "wo2", initializer=self._outer_init)
            self.register_parameters(self.W_i2, self.W_f2, self.W_c2, self.W_o2)

        # Create persistent state
        if self.persistent_state:
            self.state_h = self.create_matrix(self.batch_size, self._hidden_size, "lstm_state_h")
            self.state_m = self.create_matrix(self.batch_size, self._hidden_size, "lstm_state_m")
            self.register_free_parameters(self.state_h, self.state_m)
        else:
            self.state_h = None
            self.state_m = None
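
# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the original module).
# It exercises just the API defined above. Assumption: `input_dim` is
# normally set when the layer is wired into a deepy network; here it is
# assigned by hand so the layer can be prepared standalone.
if __name__ == "__main__":
    lstm = LSTM(hidden_size=32, input_type="sequence", output_type="sequence")
    lstm.input_dim = 16               # assumed to be set by the framework
    lstm.prepare()                    # creates gate weights and activations
    x = T.tensor3("x")                # (batch, time, input_dim)
    y = lstm.compute_tensor(x)        # (batch, time, hidden_size)
    f = theano.function([x], y, allow_input_downcast=True)
    out = f(np.random.rand(2, 5, 16).astype(FLOATX))
    print(out.shape)                  # expected: (2, 5, 32)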