Completed
Push — master ( 6bb71e...eef17b )
by Raphael
01:44
created

LSTM.get_step_inputs()   D

Complexity

Conditions 8

Size

Total Lines 28

Duplication

Lines 0
Ratio 0 %

Importance

Changes 1
Bugs 0 Features 0
Metric Value
cc 8
c 1
b 0
f 0
dl 0
loc 28
rs 4
1
#!/usr/bin/env python
2
# -*- coding: utf-8 -*-
3
4
from . import NeuralLayer
5
from neural_var import NeuralVariable
6
from deepy.utils import build_activation, FLOATX
7
import numpy as np
8
import theano
9
import theano.tensor as T
10
from collections import OrderedDict
11
12
# Values accepted by the ``output_type`` argument of LSTM.
OUTPUT_TYPES = ["sequence", "one"]
# Values accepted by the ``input_type`` argument of LSTM (same options, own list object).
INPUT_TYPES = list(OUTPUT_TYPES)
14
15
class LSTM(NeuralLayer):
    """
    Long short-term memory layer.

    One recurrence step (see ``step``) computes::

        i_t = inner_act(xi_t + h_{t-1} U_i + b_i)
        f_t = inner_act(xf_t + h_{t-1} U_f + b_f)
        c_t = f_t * c_{t-1} + i_t * outer_act(xc_t + h_{t-1} U_c + b_c)
        o_t = inner_act(xo_t + h_{t-1} U_o + b_o)
        h_t = o_t * outer_act(c_t)

    ``x*_t`` are the input projections ``x W_*``, plus ``second_input W_*2``
    when a second input is configured. They are 0 when ``input_type == "one"``.
    """

    def __init__(self, hidden_size, input_type="sequence", output_type="sequence",
                 inner_activation="sigmoid", outer_activation="tanh",
                 inner_init=None, outer_init=None, steps=None,
                 go_backwards=False,
                 persistent_state=False, batch_size=0,
                 reset_state_for_input=None, forget_bias=1,
                 mask=None,
                 second_input=None, second_input_size=None):
        """
        :param hidden_size: number of hidden units; also the layer's output_dim.
        :param input_type: one of INPUT_TYPES. "sequence" projects every time
            step of x. "one" uses x as the initial hidden state and feeds no
            per-step input.
        :param output_type: one of OUTPUT_TYPES. "sequence" returns all hidden
            states, "one" only the last one.
        :param inner_activation: activation name used for the i/f/o gates.
        :param outer_activation: activation name used for the cell candidate
            and the cell output.
        :param inner_init: initializer for the recurrent U_* matrices.
        :param outer_init: initializer for the input W_* (and W_*2) matrices.
        :param steps: passed to theano.scan as n_steps.
        :param go_backwards: passed to theano.scan; scans the sequence reversed.
        :param persistent_state: keep h/c in shared matrices across calls.
            Requires batch_size.
        :param batch_size: number of rows of the persistent state matrices.
        :param reset_state_for_input: optional column index into the input.
            At each step, rows whose value in that column equals 1 have h/c
            zeroed before the update.
        :param forget_bias: when > 0, every forget-gate bias is initialised to
            this value.
        :param mask: optional mask tensor or NeuralVariable. It is transposed
            with dimshuffle((1, 0)), presumably from (batch, time) to the
            time-major layout scan iterates over (TODO confirm with callers).
        :param second_input: optional extra tensor or NeuralVariable projected
            into every gate.
        :param second_input_size: dimension of second_input. Inferred when
            second_input is a NeuralVariable.
        :raises Exception: on an unknown input/output type, on persistent
            state without a batch size, or on a mask combined with
            input_type "one".
        """
        super(LSTM, self).__init__("lstm")
        # Key of the hidden state in the dict returned by get_initial_states().
        self.main_state = 'state'
        self._hidden_size = hidden_size
        self._input_type = input_type
        self._output_type = output_type
        self._inner_activation = inner_activation
        self._outer_activation = outer_activation
        self._inner_init = inner_init
        self._outer_init = outer_init
        self._steps = steps
        self.persistent_state = persistent_state
        self.reset_state_for_input = reset_state_for_input
        self.batch_size = batch_size
        self.go_backwards = go_backwards
        # Mask: unwrap NeuralVariable to its raw tensor.
        mask = mask.tensor if isinstance(mask, NeuralVariable) else mask
        # Compare with None explicitly. Theano refuses to evaluate some
        # symbolic variables (e.g. comparison results like ``x > 0``) as
        # booleans, so ``if mask`` can raise TypeError.
        self.mask = mask.dimshuffle((1, 0)) if mask is not None else None
        # Ordered name -> tensor map of per-step scan sequences. It is filled
        # by get_step_inputs() and decoded positionally by step().
        self._sequence_map = OrderedDict()
        # Second input: a NeuralVariable carries its own dimension.
        if isinstance(second_input, NeuralVariable):
            second_input_size = second_input.dim()
            second_input = second_input.tensor

        self.second_input = second_input
        self.second_input_size = second_input_size
        self.forget_bias = forget_bias
        if input_type not in INPUT_TYPES:
            raise Exception("Input type of LSTM is wrong: %s" % input_type)
        if output_type not in OUTPUT_TYPES:
            raise Exception("Output type of LSTM is wrong: %s" % output_type)
        if self.persistent_state and not self.batch_size:
            raise Exception("Batch size must be set for persistent state mode")
        if mask is not None and input_type == "one":
            raise Exception("Mask only works with sequence input")

    def _auto_reset_memories(self, x, h, m):
        """
        Zero the state rows whose input has a 1 in column reset_state_for_input.

        :param x: per-step input slice, indexed as x[:, column].
        :param h: previous hidden state.
        :param m: previous cell memory.
        :return: (h, m) with the selected rows multiplied by 0.
        """
        # 1 where the flag column != 1 (keep), 0 where it == 1 (reset).
        # The result is broadcast as a column over the hidden units.
        reset_matrix = T.neq(x[:, self.reset_state_for_input], 1).dimshuffle(0, 'x')
        h = h * reset_matrix
        m = m * reset_matrix
        return h, m

    def step(self, *step_args):
        """
        One scan step of the LSTM recurrence.

        :param step_args: per-step slices of the sequences, in the key order of
            self._sequence_map, followed by h_{t-1} and c_{t-1} (scan's
            outputs_info).
        :return: (h_t, c_t)
        """
        # Recover the named sequence slices from scan's positional arguments.
        n_sequences = len(self._sequence_map)
        sequence_map = dict(zip(self._sequence_map.keys(), step_args[:n_sequences]))
        h_tm1, c_tm1 = step_args[-2:]
        # Reset state
        if self.reset_state_for_input is not None:
            h_tm1, c_tm1 = self._auto_reset_memories(sequence_map["x"], h_tm1, c_tm1)

        if self._input_type == "sequence":
            xi_t, xf_t, xo_t, xc_t = [sequence_map.get(k) for k in ("xi", "xf", "xo", "xc")]
        else:
            # No per-step input: only the recurrent terms drive the gates.
            xi_t, xf_t, xo_t, xc_t = 0, 0, 0, 0

        # Add the second input's projections to every gate.
        if "xi2" in sequence_map:
            xi2, xf2, xo2, xc2 = [sequence_map.get(k) for k in ("xi2", "xf2", "xo2", "xc2")]
            xi_t += xi2
            xf_t += xf2
            xo_t += xo2
            xc_t += xc2
        # LSTM core step
        i_t = self._inner_act(xi_t + T.dot(h_tm1, self.U_i) + self.b_i)
        f_t = self._inner_act(xf_t + T.dot(h_tm1, self.U_f) + self.b_f)
        c_t = f_t * c_tm1 + i_t * self._outer_act(xc_t + T.dot(h_tm1, self.U_c) + self.b_c)
        o_t = self._inner_act(xo_t + T.dot(h_tm1, self.U_o) + self.b_o)
        h_t = o_t * self._outer_act(c_t)
        # Apply mask: where it is 0 the previous state is carried through.
        if "mask" in sequence_map:
            mask = sequence_map["mask"].dimshuffle(0, 'x')
            h_t = h_t * mask + h_tm1 * (1 - mask)
            c_t = c_t * mask + c_tm1 * (1 - mask)
        return h_t, c_t

    def get_step_inputs(self, x, mask=None, second_input=None):
        """
        Build the scan sequences consumed by step() and record their order.

        :param x: time-major input (compute_tensor dimshuffles it to
            (sequence, batch, value)), or None when input_type is "one".
        :param mask: per-call mask that overrides self.mask. Unlike the
            constructor mask it is used as-is, without transposing.
        :param second_input: per-call second input that overrides
            self.second_input.
        :return: list of sequence tensors in the same order as
            self._sequence_map.
        """
        # Create sequence map
        self._sequence_map.clear()
        if self._input_type == "sequence":
            # Input projections for the four gates.
            xi = T.dot(x, self.W_i)
            xf = T.dot(x, self.W_f)
            xc = T.dot(x, self.W_c)
            xo = T.dot(x, self.W_o)
            self._sequence_map.update([("xi", xi), ("xf", xf), ("xc", xc), ("xo", xo)])
        # Raw input is needed per step to find the reset flag column.
        if self.reset_state_for_input is not None:
            self._sequence_map["x"] = x
        # Add mask
        if mask is not None:
            self._sequence_map["mask"] = mask
        elif self.mask is not None:
            self._sequence_map["mask"] = self.mask
        # Add second input
        if second_input is None:
            second_input = self.second_input
        if second_input is not None:
            xi2 = T.dot(second_input, self.W_i2)
            xf2 = T.dot(second_input, self.W_f2)
            xc2 = T.dot(second_input, self.W_c2)
            xo2 = T.dot(second_input, self.W_o2)
            self._sequence_map.update([("xi2", xi2), ("xf2", xf2), ("xc2", xc2), ("xo2", xo2)])
        # A real list (not a dict view) so scan sees one entry per sequence
        # on both Python 2 and 3.
        return list(self._sequence_map.values())

    def get_initial_states(self, x):
        """
        Initial recurrent states.

        :param x: layer input. Only x.shape[0] is read, as the number of rows
            of the zero state.
        :return: dict with the hidden state under 'state' (self.main_state)
            and the cell memory under 'c'. Both modes return this same layout.
        """
        if self.persistent_state:
            return {'state': self.state_h, 'c': self.state_m}
        h0 = T.alloc(np.cast[FLOATX](0.), x.shape[0], self._hidden_size)
        return {'state': h0, 'c': h0}

    def compute_tensor(self, x):
        """
        Run the LSTM over x with theano.scan.

        :param x: for input_type "sequence", a 3-d tensor indexed
            (batch, sequence, value); see the dimshuffle below. For "one", the
            tensor used as the initial hidden state.
        :return: last hidden state (output_type "one") or all hidden states
            moved back to batch-major order (output_type "sequence").
        """
        initial_states = self.get_initial_states(x)
        h0, m0 = initial_states['state'], initial_states['c']
        if self._input_type == "sequence":
            # Move middle dimension to left-most position
            # (sequence, batch, value)
            x = x.dimshuffle((1, 0, 2))
            # get_step_inputs() also fills self._sequence_map, which step()
            # relies on to decode its arguments.
            sequences = self.get_step_inputs(x)
        else:
            # The single input seeds the hidden state. This presumably requires
            # input_dim == hidden_size (TODO confirm).
            h0 = x
            sequences = self.get_step_inputs(None)

        [hiddens, memories], _ = theano.scan(
            self.step,
            sequences=sequences,
            outputs_info=[h0, m0],
            n_steps=self._steps,
            go_backwards=self.go_backwards
        )

        # Save persistent state
        if self.persistent_state:
            self.register_updates((self.state_h, hiddens[-1]))
            self.register_updates((self.state_m, memories[-1]))

        if self._output_type == "one":
            return hiddens[-1]
        elif self._output_type == "sequence":
            return hiddens.dimshuffle((1, 0, 2))

    def prepare(self):
        """Create parameters, then resolve the activation functions."""
        self._setup_params()
        self._setup_functions()

    def _setup_functions(self):
        """Build the gate (inner) and cell (outer) activation callables."""
        self._inner_act = build_activation(self._inner_activation)
        self._outer_act = build_activation(self._outer_activation)

    def _setup_params(self):
        """
        Create and register the layer's weights, biases and optional state.

        Input weights W_* are registered only when input_type is "sequence".
        Second-input weights W_*2 are created when second_input_size is set.
        Persistent state matrices are created only in persistent mode.
        """
        self.output_dim = self._hidden_size

        self.W_i = self.create_weight(self.input_dim, self._hidden_size, "wi", initializer=self._outer_init)
        self.U_i = self.create_weight(self._hidden_size, self._hidden_size, "ui", initializer=self._inner_init)
        self.b_i = self.create_bias(self._hidden_size, "i")

        self.W_f = self.create_weight(self.input_dim, self._hidden_size, "wf", initializer=self._outer_init)
        self.U_f = self.create_weight(self._hidden_size, self._hidden_size, "uf", initializer=self._inner_init)
        self.b_f = self.create_bias(self._hidden_size, "f")
        if self.forget_bias > 0:
            # Initialise every forget-gate bias to the configured forget_bias.
            self.b_f.set_value(np.full((self._hidden_size,), self.forget_bias, dtype=FLOATX))

        self.W_c = self.create_weight(self.input_dim, self._hidden_size, "wc", initializer=self._outer_init)
        self.U_c = self.create_weight(self._hidden_size, self._hidden_size, "uc", initializer=self._inner_init)
        self.b_c = self.create_bias(self._hidden_size, "c")

        self.W_o = self.create_weight(self.input_dim, self._hidden_size, "wo", initializer=self._outer_init)
        self.U_o = self.create_weight(self._hidden_size, self._hidden_size, "uo", initializer=self._inner_init)
        self.b_o = self.create_bias(self._hidden_size, suffix="o")

        if self._input_type == "sequence":
            self.register_parameters(self.W_i, self.U_i, self.b_i,
                                     self.W_c, self.U_c, self.b_c,
                                     self.W_f, self.U_f, self.b_f,
                                     self.W_o, self.U_o, self.b_o)
        else:
            # No per-step input, so the W_* matrices are not trained.
            self.register_parameters(self.U_i, self.b_i,
                                     self.U_c, self.b_c,
                                     self.U_f, self.b_f,
                                     self.U_o, self.b_o)
        # Second input
        if self.second_input_size:
            self.W_i2 = self.create_weight(self.second_input_size, self._hidden_size, "wi2", initializer=self._outer_init)
            self.W_f2 = self.create_weight(self.second_input_size, self._hidden_size, "wf2", initializer=self._outer_init)
            self.W_c2 = self.create_weight(self.second_input_size, self._hidden_size, "wc2", initializer=self._outer_init)
            self.W_o2 = self.create_weight(self.second_input_size, self._hidden_size, "wo2", initializer=self._outer_init)
            self.register_parameters(self.W_i2, self.W_f2, self.W_c2, self.W_o2)

        # Create persistent state
        if self.persistent_state:
            self.state_h = self.create_matrix(self.batch_size, self._hidden_size, "lstm_state_h")
            self.state_m = self.create_matrix(self.batch_size, self._hidden_size, "lstm_state_m")
            self.register_free_parameters(self.state_h, self.state_m)
        else:
            self.state_h = None
            self.state_m = None