1
|
|
|
#!/usr/bin/env python |
2
|
|
|
# -*- coding: utf-8 -*- |
3
|
|
|
|
4
|
|
|
import numpy as np |
5
|
|
|
import theano.tensor as T |
6
|
|
|
from recurrent import RecurrentLayer |
7
|
|
|
from deepy.core.tensor_conversion import neural_computation |
8
|
|
|
from deepy.core.env import env |
9
|
|
|
|
10
|
|
|
|
11
|
|
|
class LSTM(RecurrentLayer):
    """
    Long short-term memory layer.

    Standard LSTM recurrence with input, forget and output gates over a
    memory cell.  The recurrent state variables registered with the base
    class are "state" (the hidden output h_t) and "lstm_cell" (the memory
    cell c_t).
    """

    def __init__(self, hidden_size, init_forget_bias=1, **kwargs):
        """
        :param hidden_size: dimensionality of the hidden state and cell.
        :param init_forget_bias: initial value of the forget-gate bias;
            a positive value biases the gate toward remembering early in
            training.
        """
        kwargs["hidden_size"] = hidden_size
        super(LSTM, self).__init__("LSTM", ["state", "lstm_cell"], **kwargs)
        # FIX: was hard-coded to 1, silently ignoring the parameter.
        self._init_forget_bias = init_forget_bias

    @neural_computation
    def compute_new_state(self, step_inputs):
        """
        Perform one LSTM step.

        :param step_inputs: dict with pre-computed input projections
            "xi"/"xf"/"xo"/"xc" (may be absent for input-less steps) and the
            previous recurrent values "state" (h_{t-1}) and "lstm_cell"
            (c_{t-1}).
        :return: dict with the new "state" and "lstm_cell" tensors.
        """
        # FIX: the key list was ["xi", "xf", "xc", "xo", ...], which swapped
        # the output-gate and cell-candidate projections relative to the
        # unpacked variable order.
        xi_t, xf_t, xo_t, xc_t, h_tm1, c_tm1 = map(
            step_inputs.get, ["xi", "xf", "xo", "xc", "state", "lstm_cell"])
        # Missing input projections (e.g. zero-input steps) fall back to 0.
        # NOTE(review): assumes absent keys yield None from .get; a present
        # Theano tensor is never passed through `not` here — confirm callers.
        if not xi_t:
            xi_t, xf_t, xo_t, xc_t = 0, 0, 0, 0

        # LSTM core step: gates, cell update, gated output.
        i_t = self.gate_activate(xi_t + T.dot(h_tm1, self.U_i) + self.b_i)
        f_t = self.gate_activate(xf_t + T.dot(h_tm1, self.U_f) + self.b_f)
        c_t = f_t * c_tm1 + i_t * self.activate(xc_t + T.dot(h_tm1, self.U_c) + self.b_c)
        o_t = self.gate_activate(xo_t + T.dot(h_tm1, self.U_o) + self.b_o)
        h_t = o_t * self.activate(c_t)

        return {"state": h_t, "lstm_cell": c_t}

    @neural_computation
    def merge_inputs(self, input_var, additional_inputs=None):
        """
        Concatenate all inputs and project them into the four gate
        pre-activations.

        :param input_var: main input tensor, or None/falsy when absent.
        :param additional_inputs: optional list of extra input tensors.
        :return: dict mapping "xi"/"xf"/"xc"/"xo" to the projected inputs,
            or {} when there is no input at all.
        """
        if not additional_inputs:
            additional_inputs = []
        # FIX (py3 compatibility): filter() returns a lazy iterator in
        # Python 3 — always truthy and not indexable; materialize a list.
        all_inputs = [v for v in [input_var] + additional_inputs if v]
        if not all_inputs:
            return {}
        # Concatenate along the last (feature) dimension.
        last_dim_id = all_inputs[0].ndim - 1
        merged_input = T.concatenate(all_inputs, axis=last_dim_id)
        merged_inputs = {
            "xi": T.dot(merged_input, self.W_i),
            "xf": T.dot(merged_input, self.W_f),
            "xc": T.dot(merged_input, self.W_c),
            "xo": T.dot(merged_input, self.W_o),
        }
        return merged_inputs

    def prepare(self):
        """
        Create all LSTM parameters (input projections W_*, recurrent
        matrices U_*, biases b_*) and register them.
        """
        if self._input_type == "sequence":
            all_input_dims = [self.input_dim] + self.additional_input_dims
        else:
            all_input_dims = self.additional_input_dims
        summed_input_dim = sum(all_input_dims, 0)
        self.output_dim = self.hidden_size

        self.W_i = self.create_weight(summed_input_dim, self.hidden_size, "wi", initializer=self.outer_init)
        self.U_i = self.create_weight(self.hidden_size, self.hidden_size, "ui", initializer=self.inner_init)
        self.b_i = self.create_bias(self.hidden_size, "bi")

        self.W_f = self.create_weight(summed_input_dim, self.hidden_size, "wf", initializer=self.outer_init)
        self.U_f = self.create_weight(self.hidden_size, self.hidden_size, "uf", initializer=self.inner_init)
        self.b_f = self.create_bias(self.hidden_size, "bf")
        # FIX: the bias value previously multiplied the shape tuple
        # ((hidden_size,) * bias), producing a wrong-shaped array for any
        # bias != 1.  Scale the ones vector by the bias value instead.
        self.b_f.set_value(
            np.ones((self.hidden_size,), dtype=env.FLOATX) * self._init_forget_bias)

        self.W_c = self.create_weight(summed_input_dim, self.hidden_size, "wc", initializer=self.outer_init)
        self.U_c = self.create_weight(self.hidden_size, self.hidden_size, "uc", initializer=self.inner_init)
        self.b_c = self.create_bias(self.hidden_size, "bc")

        self.W_o = self.create_weight(summed_input_dim, self.hidden_size, "wo", initializer=self.outer_init)
        self.U_o = self.create_weight(self.hidden_size, self.hidden_size, "uo", initializer=self.inner_init)
        self.b_o = self.create_bias(self.hidden_size, "bo")

        # Input projection matrices are only trainable when there is input.
        if summed_input_dim > 0:
            self.register_parameters(self.W_i, self.U_i, self.b_i,
                                     self.W_c, self.U_c, self.b_c,
                                     self.W_f, self.U_f, self.b_f,
                                     self.W_o, self.U_o, self.b_o)
        else:
            self.register_parameters(self.U_i, self.b_i,
                                     self.U_c, self.b_c,
                                     self.U_f, self.b_f,
                                     self.U_o, self.b_o)
90
|
|
|
|
91
|
|
|
|