1
|
|
|
#!/usr/bin/env python |
2
|
|
|
# -*- coding: utf-8 -*- |
3
|
|
|
|
4
|
|
|
import numpy as np |
5
|
|
|
import theano.tensor as T |
6
|
|
|
from recurrent import RecurrentLayer |
7
|
|
|
from deepy.utils import neural_computation, FLOATX |
8
|
|
|
|
9
|
|
|
|
10
|
|
|
class LSTM(RecurrentLayer):
    """
    Long short-term memory (LSTM) recurrent layer.

    Carries two recurrent state tensors between time steps:
      - ``"state"``:     the hidden output h_t
      - ``"lstm_cell"``: the memory cell c_t
    """

    def __init__(self, hidden_size, init_forget_bias=1, **kwargs):
        """
        :param hidden_size: number of hidden units (also the output dimension).
        :param init_forget_bias: initial value of the forget-gate bias b_f;
            a positive value biases the gate toward remembering early in
            training.
        """
        kwargs["hidden_size"] = hidden_size
        super(LSTM, self).__init__("LSTM", ["state", "lstm_cell"], **kwargs)
        # BUG FIX: the parameter was previously ignored (hard-coded to 1).
        self._init_forget_bias = init_forget_bias

    @neural_computation
    def compute_new_state(self, step_inputs):
        """
        Perform one LSTM step.

        :param step_inputs: dict holding the pre-projected gate inputs
            ("xi", "xf", "xo", "xc", produced by :meth:`merge_inputs`) plus
            the previous "state" (h_{t-1}) and "lstm_cell" (c_{t-1}).
        :return: dict with the new "state" and "lstm_cell".
        """
        # BUG FIX: the key list previously read ["xi", "xf", "xc", "xo", ...],
        # which swapped the output-gate and cell-candidate projections so the
        # W_c projection fed the output gate and W_o fed the candidate.
        xi_t, xf_t, xo_t, xc_t, h_tm1, c_tm1 = map(
            step_inputs.get, ["xi", "xf", "xo", "xc", "state", "lstm_cell"])
        if not xi_t:
            # No input projections were provided (input-less recurrence);
            # the gates then see only the recurrent term and the biases.
            xi_t, xf_t, xo_t, xc_t = 0, 0, 0, 0

        # LSTM core step: input gate, forget gate, cell update, output gate.
        i_t = self.gate_activate(xi_t + T.dot(h_tm1, self.U_i) + self.b_i)
        f_t = self.gate_activate(xf_t + T.dot(h_tm1, self.U_f) + self.b_f)
        c_t = f_t * c_tm1 + i_t * self.activate(xc_t + T.dot(h_tm1, self.U_c) + self.b_c)
        o_t = self.gate_activate(xo_t + T.dot(h_tm1, self.U_o) + self.b_o)
        h_t = o_t * self.activate(c_t)

        return {"state": h_t, "lstm_cell": c_t}

    @neural_computation
    def merge_inputs(self, input_var, additional_inputs=None):
        """
        Concatenate all (non-empty) inputs along their last axis and project
        the result through the four gate input matrices.

        :param input_var: main input tensor, or None/falsy when absent.
        :param additional_inputs: optional list of extra input tensors.
        :return: dict of gate projections {"xi", "xf", "xc", "xo"}, or an
            empty dict when there is no input at all.
        """
        if not additional_inputs:
            additional_inputs = []
        # List comprehension instead of filter(): on Python 3 filter() returns
        # a lazy iterator, which is always truthy and cannot be concatenated
        # reliably, so the emptiness check below would never fire.
        all_inputs = [v for v in [input_var] + additional_inputs if v]
        if not all_inputs:
            return {}
        # BUG FIX: take ndim from the first surviving input; input_var itself
        # may be None when only additional inputs are supplied.
        last_dim_id = all_inputs[0].ndim - 1
        merged_input = T.concatenate(all_inputs, axis=last_dim_id)
        return {
            "xi": T.dot(merged_input, self.W_i),
            "xf": T.dot(merged_input, self.W_f),
            "xc": T.dot(merged_input, self.W_c),
            "xo": T.dot(merged_input, self.W_o),
        }

    def prepare(self):
        """
        Create the gate parameters: an input projection W_*, a recurrent
        projection U_* and a bias b_* for each of the four gates (i, f, c, o).
        """
        if self._input_type == "sequence":
            all_input_dims = [self.input_dim] + self.additional_input_dims
        else:
            all_input_dims = self.additional_input_dims
        summed_input_dim = sum(all_input_dims, 0)
        self.output_dim = self.hidden_size

        self.W_i = self.create_weight(summed_input_dim, self.hidden_size, "wi", initializer=self.outer_init)
        self.U_i = self.create_weight(self.hidden_size, self.hidden_size, "ui", initializer=self.inner_init)
        self.b_i = self.create_bias(self.hidden_size, "i")

        self.W_f = self.create_weight(summed_input_dim, self.hidden_size, "wf", initializer=self.outer_init)
        self.U_f = self.create_weight(self.hidden_size, self.hidden_size, "uf", initializer=self.inner_init)
        self.b_f = self.create_bias(self.hidden_size, "f")
        # BUG FIX: the bias value was previously multiplied into the *shape*
        # tuple -- np.ones((h,) * bias) -- producing a wrong-rank array for any
        # init_forget_bias != 1.  Scale the values instead.
        self.b_f.set_value(
            np.ones((self.hidden_size,), dtype=FLOATX) * self._init_forget_bias)

        self.W_c = self.create_weight(summed_input_dim, self.hidden_size, "wc", initializer=self.outer_init)
        self.U_c = self.create_weight(self.hidden_size, self.hidden_size, "uc", initializer=self.inner_init)
        self.b_c = self.create_bias(self.hidden_size, "c")

        self.W_o = self.create_weight(summed_input_dim, self.hidden_size, "wo", initializer=self.outer_init)
        self.U_o = self.create_weight(self.hidden_size, self.hidden_size, "uo", initializer=self.inner_init)
        # Positional suffix argument for consistency with the other biases.
        self.b_o = self.create_bias(self.hidden_size, "o")

        # Input matrices are only meaningful when there is an input dimension;
        # otherwise register only the recurrent weights and biases.
        if summed_input_dim > 0:
            self.register_parameters(self.W_i, self.U_i, self.b_i,
                                     self.W_c, self.U_c, self.b_c,
                                     self.W_f, self.U_f, self.b_f,
                                     self.W_o, self.U_o, self.b_o)
        else:
            self.register_parameters(self.U_i, self.b_i,
                                     self.U_c, self.b_c,
                                     self.U_f, self.b_f,
                                     self.U_o, self.b_o)