Completed
Push — master (c8682f...5bbe2a)
by Raphael, created 01:33

deepy.layers.LSTM.prepare() — rank A

Complexity
    Conditions   1

Size
    Total Lines  3

Duplication
    Lines        0
    Ratio        0 %

Metric  Value
cc      1
dl      0
loc     3
rs      10
#!/usr/bin/env python
# -*- coding: utf-8 -*-

from . import NeuralLayer
from .variable import NeuralVar
from deepy.utils import build_activation, FLOATX
import numpy as np
import theano
import theano.tensor as T
from collections import OrderedDict

OUTPUT_TYPES = ["sequence", "one"]
INPUT_TYPES = ["sequence", "one"]

class LSTM(NeuralLayer):
    """
    Long short-term memory layer.
    """

    def __init__(self, hidden_size, input_type="sequence", output_type="sequence",
                 inner_activation="sigmoid", outer_activation="tanh",
                 inner_init=None, outer_init=None, steps=None,
                 go_backwards=False,
                 persistent_state=False, batch_size=0,
                 reset_state_for_input=None, forget_bias=1,
                 mask=None,
                 second_input=None, second_input_size=None):
        super(LSTM, self).__init__("lstm")
        self._hidden_size = hidden_size
        self._input_type = input_type
        self._output_type = output_type
        self._inner_activation = inner_activation
        self._outer_activation = outer_activation
        self._inner_init = inner_init
        self._outer_init = outer_init
        self._steps = steps
        self.persistent_state = persistent_state
        self.reset_state_for_input = reset_state_for_input
        self.batch_size = batch_size
        self.go_backwards = go_backwards
        # Mask: unwrap NeuralVar and move to time-major layout (time, batch)
        mask = mask.tensor if isinstance(mask, NeuralVar) else mask
        self.mask = mask.dimshuffle((1, 0)) if mask is not None else None
        self._sequence_map = OrderedDict()
        # Second input: read its size before unwrapping the NeuralVar
        if isinstance(second_input, NeuralVar):
            second_input_size = second_input.dim()
            second_input = second_input.tensor
        self.second_input = second_input
        self.second_input_size = second_input_size
        self.forget_bias = forget_bias
        if input_type not in INPUT_TYPES:
            raise ValueError("Input type of LSTM is wrong: %s" % input_type)
        if output_type not in OUTPUT_TYPES:
            raise ValueError("Output type of LSTM is wrong: %s" % output_type)
        if self.persistent_state and not self.batch_size:
            raise ValueError("Batch size must be set for persistent state mode")
        if mask is not None and input_type == "one":
            raise ValueError("Mask only works with sequence input")

    def _auto_reset_memories(self, x, h, m):
        # Zero the hidden state and memory for batch rows where the
        # designated input unit equals 1
        reset_matrix = T.neq(x[:, self.reset_state_for_input], 1).dimshuffle(0, 'x')
        h = h * reset_matrix
        m = m * reset_matrix
        return h, m

    def step(self, *args):
        # Parse the scan arguments: the sequences come first, then the
        # two recurrent states
        sequence_map = dict(zip(self._sequence_map.keys(), args[:len(self._sequence_map)]))
        h_tm1, c_tm1 = args[-2:]
        # Reset state
        if self.reset_state_for_input is not None:
            h_tm1, c_tm1 = self._auto_reset_memories(sequence_map["x"], h_tm1, c_tm1)

        if self._input_type == "sequence":
            xi_t, xf_t, xo_t, xc_t = map(sequence_map.get, ["xi", "xf", "xo", "xc"])
        else:
            xi_t, xf_t, xo_t, xc_t = 0, 0, 0, 0

        # Add second input
        if "xi2" in sequence_map:
            xi2, xf2, xo2, xc2 = map(sequence_map.get, ["xi2", "xf2", "xo2", "xc2"])
            xi_t += xi2
            xf_t += xf2
            xo_t += xo2
            xc_t += xc2
        # LSTM core step
        i_t = self._inner_act(xi_t + T.dot(h_tm1, self.U_i) + self.b_i)
        f_t = self._inner_act(xf_t + T.dot(h_tm1, self.U_f) + self.b_f)
        c_t = f_t * c_tm1 + i_t * self._outer_act(xc_t + T.dot(h_tm1, self.U_c) + self.b_c)
        o_t = self._inner_act(xo_t + T.dot(h_tm1, self.U_o) + self.b_o)
        h_t = o_t * self._outer_act(c_t)
        # Apply mask: keep the previous state wherever the mask is 0
        if "mask" in sequence_map:
            mask = sequence_map["mask"].dimshuffle(0, 'x')
            h_t = h_t * mask + h_tm1 * (1 - mask)
            c_t = c_t * mask + c_tm1 * (1 - mask)
        return h_t, c_t
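
    # For reference, the step above computes the standard LSTM recurrences;
    # the x*_t terms are the input projections precomputed in
    # produce_input_sequences() (second input omitted):
    #   i_t = inner_act(xi_t + h_{t-1} U_i + b_i)
    #   f_t = inner_act(xf_t + h_{t-1} U_f + b_f)
    #   c_t = f_t * c_{t-1} + i_t * outer_act(xc_t + h_{t-1} U_c + b_c)
    #   o_t = inner_act(xo_t + h_{t-1} U_o + b_o)
    #   h_t = o_t * outer_act(c_t)
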
    def produce_input_sequences(self, x, mask=None, second_input=None):
        # Build the ordered map of sequences that will be fed to theano.scan
        self._sequence_map.clear()
        if self._input_type == "sequence":
            # Precompute the input projections for all time steps at once
            xi = T.dot(x, self.W_i)
            xf = T.dot(x, self.W_f)
            xc = T.dot(x, self.W_c)
            xo = T.dot(x, self.W_o)
            self._sequence_map.update([("xi", xi), ("xf", xf), ("xc", xc), ("xo", xo)])
        # Reset state
        if self.reset_state_for_input is not None:
            self._sequence_map["x"] = x
        # Add mask
        if mask is not None:
            self._sequence_map["mask"] = mask
        elif self.mask is not None:
            self._sequence_map["mask"] = self.mask
        # Add second input
        if self.second_input is not None and second_input is None:
            second_input = self.second_input
        if second_input is not None:
            xi2 = T.dot(second_input, self.W_i2)
            xf2 = T.dot(second_input, self.W_f2)
            xc2 = T.dot(second_input, self.W_c2)
            xo2 = T.dot(second_input, self.W_o2)
            self._sequence_map.update([("xi2", xi2), ("xf2", xf2), ("xc2", xc2), ("xo2", xo2)])
        return list(self._sequence_map.values())

    def produce_initial_states(self, x):
        if self.persistent_state:
            return self.state_h, self.state_m
        else:
            # Zero initial hidden state and memory, one row per batch item
            h0 = T.alloc(np.cast[FLOATX](0.), x.shape[0], self._hidden_size)
            m0 = h0
            return h0, m0

    def output(self, x):
        h0, m0 = self.produce_initial_states(x)
        if self._input_type == "sequence":
            # Move the time dimension to the front for scan:
            # (batch, sequence, value) -> (sequence, batch, value)
            x = x.dimshuffle((1, 0, 2))
            sequences = self.produce_input_sequences(x)
        else:
            # A single input vector becomes the initial hidden state
            h0 = x
            sequences = self.produce_input_sequences(None)

        [hiddens, memories], _ = theano.scan(
            self.step,
            sequences=sequences,
            outputs_info=[h0, m0],
            n_steps=self._steps,
            go_backwards=self.go_backwards
        )

        # Save persistent state
        if self.persistent_state:
            self.register_updates((self.state_h, hiddens[-1]))
            self.register_updates((self.state_m, memories[-1]))

        if self._output_type == "one":
            return hiddens[-1]
        elif self._output_type == "sequence":
            # Back to batch-major layout: (sequence, batch, value)
            return hiddens.dimshuffle((1, 0, 2))

    def prepare(self):
        self._setup_params()
        self._setup_functions()

    def _setup_functions(self):
        self._inner_act = build_activation(self._inner_activation)
        self._outer_act = build_activation(self._outer_activation)

    def _setup_params(self):
        self.output_dim = self._hidden_size

        self.W_i = self.create_weight(self.input_dim, self._hidden_size, "wi", initializer=self._outer_init)
        self.U_i = self.create_weight(self._hidden_size, self._hidden_size, "ui", initializer=self._inner_init)
        self.b_i = self.create_bias(self._hidden_size, "i")

        self.W_f = self.create_weight(self.input_dim, self._hidden_size, "wf", initializer=self._outer_init)
        self.U_f = self.create_weight(self._hidden_size, self._hidden_size, "uf", initializer=self._inner_init)
        self.b_f = self.create_bias(self._hidden_size, "f")
        if self.forget_bias > 0:
            # Initialize the forget-gate bias to the configured value
            self.b_f.set_value(np.full((self._hidden_size,), self.forget_bias, dtype=FLOATX))

        self.W_c = self.create_weight(self.input_dim, self._hidden_size, "wc", initializer=self._outer_init)
        self.U_c = self.create_weight(self._hidden_size, self._hidden_size, "uc", initializer=self._inner_init)
        self.b_c = self.create_bias(self._hidden_size, "c")

        self.W_o = self.create_weight(self.input_dim, self._hidden_size, "wo", initializer=self._outer_init)
        self.U_o = self.create_weight(self._hidden_size, self._hidden_size, "uo", initializer=self._inner_init)
        self.b_o = self.create_bias(self._hidden_size, "o")

        if self._input_type == "sequence":
            self.register_parameters(self.W_i, self.U_i, self.b_i,
                                     self.W_c, self.U_c, self.b_c,
                                     self.W_f, self.U_f, self.b_f,
                                     self.W_o, self.U_o, self.b_o)
        else:
            self.register_parameters(self.U_i, self.b_i,
                                     self.U_c, self.b_c,
                                     self.U_f, self.b_f,
                                     self.U_o, self.b_o)
        # Second input
        if self.second_input_size:
            self.W_i2 = self.create_weight(self.second_input_size, self._hidden_size, "wi2", initializer=self._outer_init)
            self.W_f2 = self.create_weight(self.second_input_size, self._hidden_size, "wf2", initializer=self._outer_init)
            self.W_c2 = self.create_weight(self.second_input_size, self._hidden_size, "wc2", initializer=self._outer_init)
            self.W_o2 = self.create_weight(self.second_input_size, self._hidden_size, "wo2", initializer=self._outer_init)
            self.register_parameters(self.W_i2, self.W_f2, self.W_c2, self.W_o2)

        # Create persistent state
        if self.persistent_state:
            self.state_h = self.create_matrix(self.batch_size, self._hidden_size, "lstm_state_h")
            self.state_m = self.create_matrix(self.batch_size, self._hidden_size, "lstm_state_m")
            self.register_free_parameters(self.state_h, self.state_m)
        else:
            self.state_h = None
            self.state_m = None
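
A minimal usage sketch (not part of the analyzed commit) that exercises prepare(), the three-line method measured above. The import path and the manual input_dim assignment are assumptions for illustration; in deepy, the enclosing network normally assigns input_dim when the layer is connected.

import theano.tensor as T
from deepy.layers import LSTM  # assumption: module path for this layer

x = T.tensor3("x")  # symbolic input, (batch, time, value)
lstm = LSTM(hidden_size=128, input_type="sequence", output_type="sequence")
lstm.input_dim = 64  # assumption: set by hand here, normally wired up by the network
lstm.prepare()       # _setup_params() creates W_*, U_*, b_*; _setup_functions() builds activations
y = lstm.output(x)   # symbolic output, (batch, time, hidden_size)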